Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ec11135d
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ec11135d
编写于
3月 22, 2019
作者:
Z
Zhen Wang
提交者:
GitHub
3月 22, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #16341 from wzzju/add_channel_wise_in_quant_pass
Add channel wise in quant pass.
上级
e235882c
8965819f
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
655 addition
and
146 deletion
+655
-146
paddle/fluid/operators/fake_dequantize_op.cc
paddle/fluid/operators/fake_dequantize_op.cc
+43
-0
paddle/fluid/operators/fake_dequantize_op.cu
paddle/fluid/operators/fake_dequantize_op.cu
+58
-0
paddle/fluid/operators/fake_dequantize_op.h
paddle/fluid/operators/fake_dequantize_op.h
+27
-18
paddle/fluid/operators/fake_quantize_op.cc
paddle/fluid/operators/fake_quantize_op.cc
+49
-4
paddle/fluid/operators/fake_quantize_op.cu
paddle/fluid/operators/fake_quantize_op.cu
+103
-22
paddle/fluid/operators/fake_quantize_op.h
paddle/fluid/operators/fake_quantize_op.h
+19
-17
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
...ddle/fluid/contrib/slim/quantization/quantization_pass.py
+194
-27
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
...paddle/fluid/contrib/slim/tests/test_quantization_pass.py
+129
-44
python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
...n/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
+32
-13
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
+1
-1
未找到文件。
paddle/fluid/operators/fake_dequantize_op.cc
浏览文件 @
ec11135d
...
...
@@ -33,8 +33,51 @@ struct DequantizeFunctor<platform::CPUDeviceContext, T> {
}
};
template
<
typename
T
>
struct
ChannelDequantizeFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
dev_ctx
,
const
framework
::
Tensor
*
in
,
const
framework
::
Tensor
**
scales
,
const
int
scale_num
,
T
max_range
,
framework
::
Tensor
*
out
)
{
if
(
scale_num
==
1
)
{
const
int
channel
=
in
->
dims
()[
0
];
const
T
*
scale_factor
=
scales
[
0
]
->
data
<
T
>
();
for
(
int
i
=
0
;
i
<
channel
;
i
++
)
{
T
s
=
scale_factor
[
i
];
framework
::
Tensor
one_channel_in
=
in
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_out
=
out
->
Slice
(
i
,
i
+
1
);
auto
in_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_in
);
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_out
);
auto
&
dev
=
*
dev_ctx
.
eigen_device
();
out_e
.
device
(
dev
)
=
(
s
/
max_range
)
*
in_e
;
}
}
else
if
(
scale_num
==
2
)
{
int
batch_size
=
in
->
dims
()[
0
];
int
channel
=
in
->
dims
()[
1
];
const
T
*
scale_one
=
scales
[
0
]
->
data
<
T
>
();
const
T
*
scale_two
=
scales
[
1
]
->
data
<
T
>
();
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
framework
::
Tensor
one_batch_in
=
in
->
Slice
(
i
,
i
+
1
).
Resize
(
framework
::
slice_ddim
(
in
->
dims
(),
1
,
in
->
dims
().
size
()));
framework
::
Tensor
one_batch_out
=
out
->
Slice
(
i
,
i
+
1
).
Resize
(
framework
::
slice_ddim
(
out
->
dims
(),
1
,
out
->
dims
().
size
()));
for
(
int
j
=
0
;
j
<
channel
;
j
++
)
{
T
s
=
scale_one
[
j
];
framework
::
Tensor
one_channel_in
=
one_batch_in
.
Slice
(
j
,
j
+
1
);
framework
::
Tensor
one_channel_out
=
one_batch_out
.
Slice
(
j
,
j
+
1
);
auto
in_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_in
);
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_out
);
auto
&
dev
=
*
dev_ctx
.
eigen_device
();
out_e
.
device
(
dev
)
=
(
s
*
scale_two
[
0
]
/
max_range
)
*
in_e
;
}
}
}
}
};
template
struct
DequantizeFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
DequantizeFunctor
<
platform
::
CPUDeviceContext
,
double
>;
template
struct
ChannelDequantizeFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
ChannelDequantizeFunctor
<
platform
::
CPUDeviceContext
,
double
>;
class
FakeDequantizeMaxAbsOp
:
public
framework
::
OperatorWithKernel
{
public:
...
...
paddle/fluid/operators/fake_dequantize_op.cu
浏览文件 @
ec11135d
...
...
@@ -44,8 +44,66 @@ struct DequantizeFunctor<platform::CUDADeviceContext, T> {
}
};
template
<
typename
T
>
__global__
void
DequantizeOneScale
(
const
T
*
in
,
const
T
*
scale
,
T
max_range
,
int
num
,
int
channel
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
;
int
channel_size
=
num
/
channel
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
channel_size
;
for
(
int
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
out_c
[
i
]
=
in_c
[
i
]
*
scale
[
blockIdx
.
x
]
/
max_range
;
}
}
template
<
typename
T
>
__global__
void
DequantizeTwoScale
(
const
T
*
in
,
const
T
*
scale_one
,
const
T
*
scale_two
,
T
max_range
,
int
num
,
int
batch_size
,
int
channel
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
;
int
channel_size
=
num
/
(
batch_size
*
channel
);
int
scale_index
=
blockIdx
.
x
%
channel
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
channel_size
;
for
(
int
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
out_c
[
i
]
=
in_c
[
i
]
*
scale_one
[
scale_index
]
*
scale_two
[
0
]
/
max_range
;
}
}
template
<
typename
T
>
struct
ChannelDequantizeFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
const
framework
::
Tensor
*
in
,
const
framework
::
Tensor
**
scales
,
const
int
scale_num
,
T
max_range
,
framework
::
Tensor
*
out
)
{
const
T
*
in_data
=
in
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
if
(
scale_num
==
1
)
{
int
num
=
in
->
numel
();
int
channel
=
in
->
dims
()[
0
];
const
T
*
scale_factor
=
scales
[
0
]
->
data
<
T
>
();
int
block
=
1024
;
int
grid
=
channel
;
DequantizeOneScale
<
T
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
in_data
,
scale_factor
,
max_range
,
num
,
channel
,
out_data
);
}
else
if
(
scale_num
==
2
)
{
int
num
=
in
->
numel
();
int
batch_size
=
in
->
dims
()[
0
];
int
channel
=
in
->
dims
()[
1
];
const
T
*
scale_one
=
scales
[
0
]
->
data
<
T
>
();
const
T
*
scale_two
=
scales
[
1
]
->
data
<
T
>
();
int
block
=
1024
;
int
grid
=
batch_size
*
channel
;
DequantizeTwoScale
<
T
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
in_data
,
scale_one
,
scale_two
,
max_range
,
num
,
batch_size
,
channel
,
out_data
);
}
}
};
template
struct
DequantizeFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
DequantizeFunctor
<
platform
::
CUDADeviceContext
,
double
>;
template
struct
ChannelDequantizeFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
ChannelDequantizeFunctor
<
platform
::
CUDADeviceContext
,
double
>;
}
// namespace operators
}
// namespace paddle
...
...
paddle/fluid/operators/fake_dequantize_op.h
浏览文件 @
ec11135d
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -28,6 +29,13 @@ struct DequantizeFunctor {
framework
::
Tensor
*
out
);
};
template
<
typename
DeviceContext
,
typename
T
>
struct
ChannelDequantizeFunctor
{
void
operator
()(
const
DeviceContext
&
dev_ctx
,
const
framework
::
Tensor
*
in
,
const
framework
::
Tensor
**
scales
,
const
int
scale_num
,
T
max_range
,
framework
::
Tensor
*
out
);
};
template
<
typename
DeviceContext
,
typename
T
>
class
FakeDequantizeMaxAbsKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -54,32 +62,33 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
auto
scales
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"Scales"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
PADDLE_ENFORCE_EQ
(
scales
[
0
]
->
numel
(),
in
->
dims
()[
0
],
"The number of first scale values must be the same with "
"first dimension value of Input(X)."
);
auto
quant_bits
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"quant_bits"
);
int
max_range
=
std
::
pow
(
2
,
quant_bits
[
0
]
-
1
)
-
1
;
int
max_range
=
1
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
auto
dequant
=
DequantizeFunctor
<
DeviceContext
,
T
>
();
for
(
int64_t
i
=
0
;
i
<
in
->
dims
()[
0
];
i
++
)
{
framework
::
Tensor
one_channel_in
=
in
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_out
=
out
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_scale
=
scales
[
0
]
->
Slice
(
i
,
i
+
1
);
dequant
(
dev_ctx
,
&
one_channel_in
,
&
one_channel_scale
,
static_cast
<
T
>
(
max_range
),
&
one_channel_out
);
}
if
(
scales
.
size
()
==
2
)
{
int
scale_num
=
scales
.
size
();
if
(
scale_num
==
1
)
{
PADDLE_ENFORCE_EQ
(
scales
[
0
]
->
numel
(),
in
->
dims
()[
0
],
"The number of first scale values must be the same with "
"first dimension value of Input(X) when the `Scales` has only one "
"element."
);
max_range
*=
(
std
::
pow
(
2
,
quant_bits
[
0
]
-
1
)
-
1
);
}
else
if
(
scale_num
==
2
)
{
PADDLE_ENFORCE_EQ
(
scales
[
0
]
->
numel
(),
in
->
dims
()[
1
],
"The number of first scale values must be the same with "
"second dimension value of Input(X) when the `Scales` has two "
"elements."
);
PADDLE_ENFORCE_EQ
(
scales
[
1
]
->
numel
(),
1
,
"The second scale tensor should only have one value at now."
);
max_range
=
std
::
pow
(
2
,
quant_bits
[
1
]
-
1
)
-
1
;
dequant
(
dev_ctx
,
out
,
scales
[
1
],
static_cast
<
T
>
(
max_range
),
out
);
max_range
*=
(
std
::
pow
(
2
,
quant_bits
[
0
]
-
1
)
-
1
)
*
(
std
::
pow
(
2
,
quant_bits
[
1
]
-
1
)
-
1
);
}
ChannelDequantizeFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
,
scales
.
data
(),
scale_num
,
static_cast
<
T
>
(
max_range
),
out
);
}
};
...
...
paddle/fluid/operators/fake_quantize_op.cc
浏览文件 @
ec11135d
...
...
@@ -37,6 +37,21 @@ struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
template
struct
FindAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
<
typename
T
>
struct
FindChannelAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
T
*
in
,
const
int
num
,
const
int
channel
,
T
*
out
)
{
const
int
channel_size
=
num
/
channel
;
for
(
int
i
=
0
;
i
<
channel
;
i
++
)
{
auto
*
start
=
in
+
i
*
channel_size
;
auto
*
end
=
in
+
(
i
+
1
)
*
channel_size
;
out
[
i
]
=
std
::
abs
(
*
(
std
::
max_element
(
start
,
end
,
Compare
<
T
>
())));
}
}
};
template
struct
FindChannelAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
<
typename
T
>
struct
ClipAndFakeQuantFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
...
...
@@ -53,6 +68,36 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
template
struct
ClipAndFakeQuantFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
<
typename
T
>
struct
ChannelClipAndFakeQuantFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
const
int
channel
,
framework
::
Tensor
*
out
)
{
auto
*
scale_data
=
scale
.
data
<
T
>
();
auto
*
in_data
=
in
.
data
<
T
>
();
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
int
channel_size
=
in
.
numel
()
/
channel
;
platform
::
Transform
<
platform
::
CPUDeviceContext
>
trans
;
for
(
int
i
=
0
;
i
<
channel
;
i
++
)
{
T
s
=
scale_data
[
i
];
auto
*
start
=
in_data
+
i
*
channel_size
;
auto
*
end
=
in_data
+
(
i
+
1
)
*
channel_size
;
trans
(
ctx
,
start
,
end
,
out_data
+
i
*
channel_size
,
ClipFunctor
<
T
>
(
-
s
,
s
));
}
for
(
int
i
=
0
;
i
<
channel
;
i
++
)
{
T
s
=
scale_data
[
i
];
framework
::
Tensor
one_channel_out
=
out
->
Slice
(
i
,
i
+
1
);
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_out
);
out_e
.
device
(
*
ctx
.
eigen_device
())
=
(
bin_cnt
/
s
*
out_e
).
round
();
}
}
};
template
struct
ChannelClipAndFakeQuantFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
<
typename
T
>
struct
FindRangeAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
...
...
@@ -169,10 +214,10 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FakeChannelWiseQuantizeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"OutScale
s
"
),
"Output(Scale
s
) of FakeChannelWiseQuantizeOp should not be null."
);
ctx
->
HasOutput
(
"OutScale"
),
"Output(Scale) of FakeChannelWiseQuantizeOp should not be null."
);
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"OutScale
s
"
,
{
ctx
->
GetInputDim
(
"X"
)[
0
]});
ctx
->
SetOutputDim
(
"OutScale"
,
{
ctx
->
GetInputDim
(
"X"
)[
0
]});
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
...
...
@@ -192,7 +237,7 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
AddOutput
(
"Out"
,
"(Tensor) Output of quantized low level tensor, "
"but also saved as float data type."
);
AddOutput
(
"OutScale
s
"
,
"(Tensor) Current channel wise scale"
);
AddOutput
(
"OutScale"
,
"(Tensor) Current channel wise scale"
);
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8)"
)
.
SetDefault
(
8
)
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
...
...
paddle/fluid/operators/fake_quantize_op.cu
浏览文件 @
ec11135d
...
...
@@ -74,6 +74,45 @@ struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
template
struct
FindAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
<
typename
T
>
__global__
void
FindChannelAbsMaxKernel
(
const
T
*
in
,
const
int
n
,
const
int
c
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
;
int
channel_size
=
n
/
c
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
extern
__shared__
T
shared_max_data
[];
shared_max_data
[
tid
]
=
T
(
0
);
for
(
int
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
T
tmp
=
fabs
(
in_c
[
i
]);
if
(
tmp
>
shared_max_data
[
tid
])
{
shared_max_data
[
tid
]
=
tmp
;
}
}
__syncthreads
();
for
(
int
i
=
blockDim
.
x
/
2
;
i
>
0
;
i
>>=
1
)
{
if
(
tid
<
i
&&
(
shared_max_data
[
tid
]
<
shared_max_data
[
tid
+
i
]))
{
shared_max_data
[
tid
]
=
shared_max_data
[
tid
+
i
];
}
__syncthreads
();
}
if
(
tid
==
0
)
{
out
[
blockIdx
.
x
]
=
shared_max_data
[
0
];
}
}
template
<
typename
T
>
struct
FindChannelAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
T
*
in
,
const
int
num
,
const
int
channel
,
T
*
out
)
{
int
block
=
1024
;
int
grid
=
channel
;
FindChannelAbsMaxKernel
<
T
><<<
grid
,
block
,
1024
*
sizeof
(
T
),
ctx
.
stream
()
>>>
(
in
,
num
,
channel
,
out
);
}
};
template
struct
FindChannelAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
<
typename
T
>
__global__
void
ClipAndQuantKernel
(
const
T
*
in
,
const
T
*
scale
,
const
int
bin_cnt
,
const
int
n
,
T
*
out
)
{
...
...
@@ -82,14 +121,76 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
T
s
=
scale
[
0
];
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
x
=
in
[
bid
];
T
x
=
in
[
i
];
T
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt
/
s
*
v
;
out
[
bid
]
=
round
(
v
);
out
[
i
]
=
round
(
v
);
}
}
template
<
typename
T
>
struct
ClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
framework
::
Tensor
*
out
)
{
int
num
=
in
.
numel
();
int
block
=
1024
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ClipAndQuantKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
in_data
,
scale_data
,
bin_cnt
,
num
,
out_data
);
}
};
template
struct
ClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
<
typename
T
>
__global__
void
ChannelClipAndQuantKernel
(
const
T
*
in
,
const
T
*
scale
,
const
int
bin_cnt
,
const
int
n
,
const
int
c
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
;
int
channel_size
=
n
/
c
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
channel_size
;
T
s
=
scale
[
blockIdx
.
x
];
for
(
int
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
T
x
=
in_c
[
i
];
T
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt
/
s
*
v
;
out_c
[
i
]
=
round
(
v
);
}
}
template
<
typename
T
>
struct
ChannelClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
const
int
channel
,
framework
::
Tensor
*
out
)
{
int
num
=
in
.
numel
();
int
block
=
1024
;
int
grid
=
channel
;
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ChannelClipAndQuantKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
in_data
,
scale_data
,
bin_cnt
,
num
,
channel
,
out_data
);
}
};
template
struct
ChannelClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
<
typename
T
>
__global__
void
FindRangeAbsMaxAndFillArray
(
const
T
*
cur_scale
,
const
T
*
last_scale
,
...
...
@@ -182,26 +283,6 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
template
struct
FindMovingAverageAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
<
typename
T
>
struct
ClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
framework
::
Tensor
*
out
)
{
int
num
=
in
.
numel
();
int
block
=
1024
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ClipAndQuantKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
in_data
,
scale_data
,
bin_cnt
,
num
,
out_data
);
}
};
template
struct
ClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
float
>;
}
// namespace operators
}
// namespace paddle
...
...
paddle/fluid/operators/fake_quantize_op.h
浏览文件 @
ec11135d
...
...
@@ -42,6 +42,19 @@ struct FindRangeAbsMaxFunctor {
framework
::
Tensor
*
scales_arr
,
framework
::
Tensor
*
out_scale
);
};
template
<
typename
DeviceContext
,
typename
T
>
struct
FindChannelAbsMaxFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
T
*
in
,
const
int
num
,
const
int
channel
,
T
*
out
);
};
template
<
typename
DeviceContext
,
typename
T
>
struct
ChannelClipAndFakeQuantFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
const
int
channel
,
framework
::
Tensor
*
out
);
};
template
<
typename
DeviceContext
,
typename
T
>
struct
FindMovingAverageAbsMaxFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
in_accum
,
...
...
@@ -78,29 +91,18 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out_scale
s
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales
"
);
T
*
out_scale
s_data
=
out_scales
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale
"
);
T
*
out_scale
_data
=
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
find_abs_max
=
FindAbsMaxFunctor
<
DeviceContext
,
T
>
();
for
(
int64_t
i
=
0
;
i
<
in
->
dims
()[
0
];
i
++
)
{
framework
::
Tensor
one_channel
=
in
->
Slice
(
i
,
i
+
1
);
const
T
*
one_channel_data
=
one_channel
.
data
<
T
>
();
find_abs_max
(
dev_ctx
,
one_channel_data
,
one_channel
.
numel
(),
&
out_scales_data
[
i
]);
}
auto
clip_quant
=
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
();
for
(
int64_t
i
=
0
;
i
<
in
->
dims
()[
0
];
i
++
)
{
framework
::
Tensor
one_channel_in
=
in
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_out
=
out
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_scale
=
out_scales
->
Slice
(
i
,
i
+
1
);
clip_quant
(
dev_ctx
,
one_channel_in
,
one_channel_scale
,
bin_cnt
,
&
one_channel_out
);
}
FindChannelAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
->
data
<
T
>
(),
in
->
numel
(),
in
->
dims
()[
0
],
out_scale_data
);
ChannelClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
in
,
*
out_scale
,
bin_cnt
,
in
->
dims
()[
0
],
out
);
}
};
...
...
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
浏览文件 @
ec11135d
...
...
@@ -22,6 +22,7 @@ from ....framework import IrGraph
from
....framework
import
IrNode
from
....framework
import
Program
from
....initializer
import
Constant
from
....initializer
import
NumpyArrayInitializer
from
....
import
unique_name
__all__
=
[
...
...
@@ -54,14 +55,15 @@ class QuantizationTransformPass(object):
the bias is not quantized.
activation_bits (int): quantization bit number for activation.
activation_quantize_type (str): quantization type for activation,
now support 'abs_max', 'range_abs_max'
. If use 'abs_max' mode,
the quantization scale will be calculated dynamically each step
in both training and testing period. If use 'range_abs_max',
a static quantization scale will be calculated during training
and used in inference.
now support 'abs_max', 'range_abs_max'
and 'moving_average_abs_max'.
If use 'abs_max' mode, the quantization scale will be calculated
dynamically each step in both training and testing period. If use
'range_abs_max', a static quantization scale will be calculated
during training
and used in inference.
weight_quantize_type (str): quantization type for weights,
support 'abs_max'. The 'range_abs_max' usually is not used for
weight, since weights are fixed once the model is well trained.
support 'abs_max' and 'channel_wise_abs_max'. The 'range_abs_max'
usually is not used for weight, since weights are fixed once the
model is well trained.
window_size (int): the window size for 'range_abs_max' quantization.
Examples:
...
...
@@ -84,7 +86,11 @@ class QuantizationTransformPass(object):
self
.
_weight_bits
=
weight_bits
self
.
_activation_bits
=
activation_bits
quant_type
=
[
'abs_max'
,
'range_abs_max'
,
'moving_average_abs_max'
]
quant_type
=
[
'abs_max'
,
'channel_wise_abs_max'
,
'range_abs_max'
,
'moving_average_abs_max'
]
assert
activation_quantize_type
!=
'channel_wise_abs_max'
,
"The activation quantization type does not support 'channel_wise_abs_max'."
if
activation_quantize_type
not
in
quant_type
:
raise
ValueError
(
"Unknown activation_quantize_type : '%s'. It can only be "
,
...
...
@@ -93,7 +99,7 @@ class QuantizationTransformPass(object):
if
weight_quantize_type
not
in
quant_type
:
raise
ValueError
(
"Unknown weight_quantize_type: '%s'. It can only be "
,
"'abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
,
"'abs_max' or '
channel_wise_abs_max' or '
range_abs_max' or 'moving_average_abs_max'."
,
str
(
weight_quantize_type
))
self
.
_activation_quantize_type
=
activation_quantize_type
...
...
@@ -103,6 +109,7 @@ class QuantizationTransformPass(object):
self
.
_need_initialized
=
collections
.
OrderedDict
()
self
.
_quantizable_ops
=
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
]
self
.
_conv_ops
=
[
'conv2d'
,
'depthwise_conv2d'
]
self
.
_quantizable_grad_ops
=
[
'%s_grad'
%
(
op
)
for
op
in
self
.
_quantizable_ops
]
...
...
@@ -135,10 +142,26 @@ class QuantizationTransformPass(object):
else
self
.
_activation_bits
quant_type
=
self
.
_weight_quantize_type
if
var_node
.
name
()
\
in
persistable_vars
else
self
.
_activation_quantize_type
quant_var_node
,
scale_var_node
=
self
.
_insert_quant_op
(
graph
,
var_node
,
quant_bits
,
quant_type
)
dequant_var_node
=
self
.
_insert_dequant_op
(
graph
,
quant_var_node
,
scale_var_node
,
quant_bits
)
if
quant_type
==
'channel_wise_abs_max'
:
assert
var_node
.
name
(
)
in
persistable_vars
,
"'channel_wise_abs_max' can only be applied on weights."
if
op
.
name
()
in
self
.
_conv_ops
:
quant_var_node
,
scale_var_node
=
self
.
_insert_channel_quant_op
(
graph
,
var_node
,
quant_bits
)
dequant_var_node
=
self
.
_insert_channel_dequant_op
(
graph
,
quant_var_node
,
[
scale_var_node
],
[
quant_bits
])
else
:
quant_var_node
,
scale_var_node
=
self
.
_insert_quant_op
(
graph
,
var_node
,
quant_bits
,
'abs_max'
)
dequant_var_node
=
self
.
_insert_dequant_op
(
graph
,
quant_var_node
,
scale_var_node
,
quant_bits
)
else
:
quant_var_node
,
scale_var_node
=
self
.
_insert_quant_op
(
graph
,
var_node
,
quant_bits
,
quant_type
)
dequant_var_node
=
self
.
_insert_dequant_op
(
graph
,
quant_var_node
,
scale_var_node
,
quant_bits
)
dequantized_vars
[
var_node
.
name
()]
=
dequant_var_node
graph
.
update_input_link
(
var_node
,
dequant_var_node
,
op
)
...
...
@@ -244,7 +267,7 @@ class QuantizationTransformPass(object):
scale_var_node
=
graph
.
create_var_node
(
name
=
self
.
_quantized_scale_name
(
var_node
.
name
()),
var_type
=
var_node
.
type
(),
shape
=
var_node
.
shape
()
,
shape
=
[
1
]
,
var_dtype
=
var_node
.
dtype
())
quant_op_node
=
graph
.
create_op_node
(
op_type
=
'fake_quantize_abs_max'
,
...
...
@@ -384,6 +407,36 @@ class QuantizationTransformPass(object):
return
quant_var_node
,
scale_out_node
def
_insert_channel_quant_op
(
self
,
graph
,
var_node
,
quant_bits
):
"""
Insert fake_channel_wise_quantize_abs_max op in the graph.
"""
assert
var_node
.
is_var
(),
'{} is not a var'
.
format
(
var_node
.
name
())
quant_var_node
=
graph
.
create_var_node
(
name
=
self
.
_quantized_var_name
(
var_node
.
name
()),
var_type
=
var_node
.
type
(),
shape
=
var_node
.
shape
(),
var_dtype
=
var_node
.
dtype
())
scale_var_node
=
graph
.
create_var_node
(
name
=
self
.
_quantized_scale_name
(
var_node
.
name
()),
var_type
=
var_node
.
type
(),
shape
=
[
var_node
.
shape
()[
0
]],
var_dtype
=
var_node
.
dtype
())
quant_op_node
=
graph
.
create_op_node
(
op_type
=
'fake_channel_wise_quantize_abs_max'
,
attrs
=
{
'bit_length'
:
quant_bits
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
},
inputs
=
{
'X'
:
var_node
},
outputs
=
{
'Out'
:
quant_var_node
,
'OutScale'
:
scale_var_node
})
graph
.
link_to
(
var_node
,
quant_op_node
)
graph
.
link_to
(
quant_op_node
,
quant_var_node
)
graph
.
link_to
(
quant_op_node
,
scale_var_node
)
return
quant_var_node
,
scale_var_node
def
_insert_dequant_op
(
self
,
graph
,
var_node
,
scale_var_node
,
quant_bits
):
"""
Insert fake_dequantize_op in the graph.
...
...
@@ -410,6 +463,33 @@ class QuantizationTransformPass(object):
graph
.
link_to
(
dequant_op_node
,
dequant_var_node
)
return
dequant_var_node
def
_insert_channel_dequant_op
(
self
,
graph
,
var_node
,
scale_var_nodes
,
quant_bits
):
"""
Insert fake_channel_wise_dequantize_max_abs in the graph.
"""
assert
var_node
.
is_var
(),
'{} is not a var'
.
format
(
var_node
.
name
())
dequant_var_node
=
graph
.
create_var_node
(
name
=
self
.
_dequantized_var_name
(
var_node
.
name
()),
var_type
=
var_node
.
type
(),
shape
=
var_node
.
shape
(),
var_dtype
=
var_node
.
dtype
())
dequant_op_node
=
graph
.
create_op_node
(
op_type
=
'fake_channel_wise_dequantize_max_abs'
,
attrs
=
{
'quant_bits'
:
quant_bits
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
},
inputs
=
{
'X'
:
var_node
,
'Scales'
:
scale_var_nodes
},
outputs
=
{
'Out'
:
dequant_var_node
})
graph
.
link_to
(
var_node
,
dequant_op_node
)
for
scale_n
in
scale_var_nodes
:
graph
.
link_to
(
scale_n
,
dequant_op_node
)
graph
.
link_to
(
dequant_op_node
,
dequant_var_node
)
return
dequant_var_node
def
_quantized_var_name
(
self
,
var_name
):
"""
Return quantized variable name for the input `var_name`.
...
...
@@ -442,7 +522,7 @@ class QuantizationFreezePass(object):
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors.
weight_bits (int): quantization bit number for weights.
activation_bits (int): quantization bit number for activation.
weight_quantize_type (str): quantization type for weights, support 'abs_max'.
weight_quantize_type (str): quantization type for weights, support 'abs_max'
and 'channel_wise_abs_max'
.
The 'range_abs_max' usually is not used for weight, since weights are fixed once the
model is well trained.
"""
...
...
@@ -463,11 +543,15 @@ class QuantizationFreezePass(object):
self
.
_activation_bits
=
activation_bits
self
.
_weight_quantize_type
=
weight_quantize_type
self
.
_quantizable_ops
=
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
]
self
.
_conv_ops
=
[
'conv2d'
,
'depthwise_conv2d'
]
self
.
_fake_quant_op_names
=
[
'fake_quantize_abs_max'
,
'fake_quantize_range_abs_max'
,
'fake_quantize_moving_average_abs_max'
'fake_quantize_moving_average_abs_max'
,
'fake_channel_wise_quantize_abs_max'
]
self
.
_fake_dequant_op_names
=
[
'fake_dequantize_max_abs'
,
'fake_channel_wise_dequantize_max_abs'
]
self
.
_fake_dequant_op_names
=
[
'fake_dequantize_max_abs'
]
self
.
_op_input_rename_map
=
collections
.
OrderedDict
()
self
.
_op_output_rename_map
=
collections
.
OrderedDict
()
self
.
_var_scale_map
=
collections
.
OrderedDict
()
...
...
@@ -489,20 +573,27 @@ class QuantizationFreezePass(object):
if
self
.
_weight_quantize_type
==
'abs_max'
:
param
=
self
.
_load_var
(
input_arg_name
)
scale_v
=
np
.
max
(
np
.
abs
(
param
))
elif
self
.
_weight_quantize_type
==
'channel_wise_abs_max'
:
param
=
self
.
_load_var
(
input_arg_name
)
if
len
(
param
.
shape
)
==
4
:
# conv2d or depthwise_conv2d
scale_v
=
[]
for
i
in
range
(
param
.
shape
[
0
]):
scale_v
.
append
(
np
.
max
(
np
.
abs
(
param
[
i
])))
else
:
scale_v
=
np
.
max
(
np
.
abs
(
param
))
else
:
scale_v
=
self
.
_load_var
(
op_node
.
output
(
'OutScale'
)[
0
])[
0
]
self
.
_var_scale_map
[
input_arg_name
]
=
scale_v
else
:
scale_v
=
graph
.
var_node
(
op_node
.
output
(
'OutScale'
)[
0
])
self
.
_var_scale_map
[
input_arg_name
]
=
scale_v
if
input_arg_name
in
persistable_vars
:
self
.
_remove_fake_quant_and_dequant_op
(
graph
,
op_node
)
# quantize weight and restore
param_v
=
self
.
_load_var
(
input_arg_name
)
quantized_param_v
=
self
.
_quant
(
param_v
,
scale_v
,
self
.
_weight_bits
)
self
.
_restore_var
(
input_arg_name
,
quantized_param_v
)
else
:
scale_v
=
graph
.
var_node
(
op_node
.
output
(
'OutScale'
)[
0
])
self
.
_var_scale_map
[
input_arg_name
]
=
scale_v
ops
=
graph
.
all_op_nodes
()
for
op_node
in
ops
:
...
...
@@ -514,7 +605,10 @@ class QuantizationFreezePass(object):
for
op_node
in
ops
:
op_name
=
op_node
.
name
()
if
op_name
in
self
.
_quantizable_ops
:
self
.
_insert_post_dequant_op
(
graph
,
op_node
)
if
self
.
_weight_quantize_type
==
'channel_wise_abs_max'
and
op_name
in
self
.
_conv_ops
:
self
.
_insert_post_channel_dequant_op
(
graph
,
op_node
)
else
:
self
.
_insert_post_dequant_op
(
graph
,
op_node
)
for
op_node
in
ops
:
# insert dequant_op after fc/conv, need to rename inputs of the followed ops
...
...
@@ -538,9 +632,73 @@ class QuantizationFreezePass(object):
self
.
_op_input_rename_map
[
k
]
=
self
.
_op_input_rename_map
[
v
]
graph
.
safe_remove_nodes
(
op_node
)
def
_insert_post_channel_dequant_op
(
self
,
graph
,
op_node
):
persistable_vars
=
[
p
.
name
()
for
p
in
graph
.
all_persistable_nodes
()]
for
var_node
in
op_node
.
inputs
:
name
=
var_node
.
name
()
if
name
in
self
.
_op_input_rename_map
:
old_in
=
graph
.
var_node
(
name
)
new_in
=
graph
.
var_node
(
self
.
_op_input_rename_map
[
name
])
new_in
.
clear_outputs
()
graph
.
update_input_link
(
old_in
,
new_in
,
op_node
)
original_var_name
=
self
.
_original_var_name
(
name
)
scale_v
=
self
.
_var_scale_map
[
original_var_name
]
if
original_var_name
in
persistable_vars
:
assert
isinstance
(
scale_v
,
list
),
'The scale of parameter %s is not a list.'
%
(
original_var_name
)
channel_scale
=
np
.
array
(
scale_v
)
else
:
assert
isinstance
(
scale_v
,
IrNode
)
scale_var_node
=
self
.
_var_scale_map
[
original_var_name
]
if
len
(
op_node
.
outputs
)
!=
1
:
raise
ValueError
(
"Only support one output, but op %s has"
" more than one output."
%
(
op_node
.
name
()))
output_var_node
=
op_node
.
outputs
[
0
]
weight_scale_node
=
graph
.
create_persistable_node
(
name
=
unique_name
.
generate
(
'channel_scale'
),
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
shape
=
[
channel_scale
.
shape
[
0
]],
var_dtype
=
output_var_node
.
dtype
())
init_program
=
Program
()
weight_scale_var
=
init_program
.
global_block
().
create_var
(
name
=
weight_scale_node
.
name
(),
shape
=
weight_scale_node
.
shape
(),
dtype
=
weight_scale_node
.
dtype
(),
type
=
weight_scale_node
.
type
(),
lod_level
=
weight_scale_node
.
var
().
lod_level
(),
persistable
=
weight_scale_node
.
persistable
())
initializer
=
NumpyArrayInitializer
(
value
=
channel_scale
)
initializer
(
weight_scale_var
,
init_program
.
global_block
())
exe
=
Executor
(
self
.
_place
)
exe
.
run
(
program
=
init_program
,
scope
=
self
.
_scope
)
dequant_var_node
=
graph
.
create_var_node
(
name
=
self
.
_dequantized_var_name
(
output_var_node
.
name
()),
var_type
=
output_var_node
.
type
(),
shape
=
output_var_node
.
shape
(),
var_dtype
=
output_var_node
.
dtype
())
dequant_op_node
=
graph
.
create_op_node
(
op_type
=
'fake_channel_wise_dequantize_max_abs'
,
attrs
=
{
'quant_bits'
:
[
self
.
_weight_bits
,
self
.
_activation_bits
],
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
},
inputs
=
{
'X'
:
output_var_node
,
'Scales'
:
[
weight_scale_node
,
scale_var_node
]
},
outputs
=
{
'Out'
:
dequant_var_node
})
graph
.
link_to
(
output_var_node
,
dequant_op_node
)
graph
.
link_to
(
scale_var_node
,
dequant_op_node
)
graph
.
link_to
(
weight_scale_node
,
dequant_op_node
)
graph
.
link_to
(
dequant_op_node
,
dequant_var_node
)
self
.
_op_output_rename_map
[
output_var_node
.
name
()]
=
dequant_var_node
return
dequant_var_node
def
_insert_post_dequant_op
(
self
,
graph
,
op_node
):
max_range
=
None
scale_var_node
=
None
persistable_vars
=
[
p
.
name
()
for
p
in
graph
.
all_persistable_nodes
()]
for
var_node
in
op_node
.
inputs
:
name
=
var_node
.
name
()
...
...
@@ -637,7 +795,12 @@ class QuantizationFreezePass(object):
or
isinstance
(
v
,
np
.
float64
)
def
_quant
(
self
,
x
,
scale
,
num_bits
):
return
np
.
round
(
x
/
scale
*
((
1
<<
(
num_bits
-
1
))
-
1
))
if
isinstance
(
scale
,
list
):
for
i
,
s
in
enumerate
(
scale
):
x
[
i
]
=
np
.
round
(
x
[
i
]
/
s
*
((
1
<<
(
num_bits
-
1
))
-
1
))
return
x
else
:
return
np
.
round
(
x
/
scale
*
((
1
<<
(
num_bits
-
1
))
-
1
))
class
ConvertToInt8Pass
(
object
):
...
...
@@ -731,9 +894,13 @@ class TransformForMobilePass(object):
def
__init__
(
self
):
self
.
_fake_quant_op_names
=
[
'fake_quantize_abs_max'
,
'fake_quantize_range_abs_max'
'fake_quantize_abs_max'
,
'fake_quantize_range_abs_max'
,
'fake_quantize_moving_average_abs_max'
,
'fake_channel_wise_quantize_abs_max'
]
self
.
_fake_dequant_op_names
=
[
'fake_dequantize_max_abs'
,
'fake_channel_wise_dequantize_max_abs'
]
self
.
_fake_dequant_op_names
=
[
'fake_dequantize_max_abs'
]
def
apply
(
self
,
graph
):
"""
...
...
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
浏览文件 @
ec11135d
...
...
@@ -127,7 +127,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
arg_name
.
endswith
(
'.quantized.dequantized'
))
self
.
assertTrue
(
arg_name
in
quantized_ops
)
def
linear_fc_quant
(
self
,
quant_type
,
for_ci
=
False
):
def
linear_fc_quant
(
self
,
activation_
quant_type
,
for_ci
=
False
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
,
startup
):
...
...
@@ -140,14 +140,15 @@ class TestQuantizationTransformPass(unittest.TestCase):
transform_pass
=
QuantizationTransformPass
(
scope
=
fluid
.
global_scope
(),
place
=
place
,
activation_quantize_type
=
quant_type
)
activation_quantize_type
=
activation_
quant_type
)
transform_pass
.
apply
(
graph
)
if
not
for_ci
:
marked_nodes
=
set
()
for
op
in
graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
marked_nodes
.
add
(
op
)
graph
.
draw
(
'.'
,
'quantize_fc_'
+
quant_type
,
marked_nodes
)
graph
.
draw
(
'.'
,
'quantize_fc_'
+
activation_quant_type
,
marked_nodes
)
program
=
graph
.
to_program
()
self
.
check_program
(
transform_pass
,
program
)
val_graph
=
IrGraph
(
core
.
Graph
(
program
.
desc
),
for_test
=
False
)
...
...
@@ -156,7 +157,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
for
op
in
val_graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
val_marked_nodes
.
add
(
op
)
val_graph
.
draw
(
'.'
,
'val_fc_'
+
quant_type
,
val_marked_nodes
)
val_graph
.
draw
(
'.'
,
'val_fc_'
+
activation_quant_type
,
val_marked_nodes
)
def
test_linear_fc_quant_abs_max
(
self
):
self
.
linear_fc_quant
(
'abs_max'
,
for_ci
=
True
)
...
...
@@ -167,7 +169,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
def
test_linear_fc_quant_moving_average_abs_max
(
self
):
self
.
linear_fc_quant
(
'moving_average_abs_max'
,
for_ci
=
True
)
def
residual_block_quant
(
self
,
quant_type
,
for_ci
=
False
):
def
residual_block_quant
(
self
,
activation_
quant_type
,
for_ci
=
False
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
,
startup
):
...
...
@@ -180,14 +182,15 @@ class TestQuantizationTransformPass(unittest.TestCase):
transform_pass
=
QuantizationTransformPass
(
scope
=
fluid
.
global_scope
(),
place
=
place
,
activation_quantize_type
=
quant_type
)
activation_quantize_type
=
activation_
quant_type
)
transform_pass
.
apply
(
graph
)
if
not
for_ci
:
marked_nodes
=
set
()
for
op
in
graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
marked_nodes
.
add
(
op
)
graph
.
draw
(
'.'
,
'quantize_residual_'
+
quant_type
,
marked_nodes
)
graph
.
draw
(
'.'
,
'quantize_residual_'
+
activation_quant_type
,
marked_nodes
)
program
=
graph
.
to_program
()
self
.
check_program
(
transform_pass
,
program
)
val_graph
=
IrGraph
(
core
.
Graph
(
program
.
desc
),
for_test
=
False
)
...
...
@@ -196,7 +199,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
for
op
in
val_graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
val_marked_nodes
.
add
(
op
)
val_graph
.
draw
(
'.'
,
'val_residual_'
+
quant_type
,
val_marked_nodes
)
val_graph
.
draw
(
'.'
,
'val_residual_'
+
activation_quant_type
,
val_marked_nodes
)
def
test_residual_block_abs_max
(
self
):
self
.
residual_block_quant
(
'abs_max'
,
for_ci
=
True
)
...
...
@@ -209,7 +213,12 @@ class TestQuantizationTransformPass(unittest.TestCase):
class
TestQuantizationFreezePass
(
unittest
.
TestCase
):
def
freeze_graph
(
self
,
use_cuda
,
seed
,
quant_type
,
for_ci
=
False
):
def
freeze_graph
(
self
,
use_cuda
,
seed
,
activation_quant_type
,
weight_quant_type
=
'abs_max'
,
for_ci
=
False
):
def
build_program
(
main
,
startup
,
is_test
):
main
.
random_seed
=
seed
startup
.
random_seed
=
seed
...
...
@@ -243,7 +252,12 @@ class TestQuantizationFreezePass(unittest.TestCase):
with
fluid
.
scope_guard
(
scope
):
exe
.
run
(
startup
)
transform_pass
=
QuantizationTransformPass
(
scope
=
scope
,
place
=
place
,
activation_quantize_type
=
quant_type
)
scope
=
scope
,
place
=
place
,
activation_quantize_type
=
activation_quant_type
,
weight_quantize_type
=
weight_quant_type
)
#transform_pass = QuantizationTransformPass(
# scope=scope, place=place, activation_quantize_type=activation_quant_type)
transform_pass
.
apply
(
main_graph
)
transform_pass
.
apply
(
test_graph
)
dev_name
=
'_gpu_'
if
use_cuda
else
'_cpu_'
...
...
@@ -252,12 +266,14 @@ class TestQuantizationFreezePass(unittest.TestCase):
for
op
in
main_graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
marked_nodes
.
add
(
op
)
main_graph
.
draw
(
'.'
,
'main'
+
dev_name
+
quant_type
,
marked_nodes
)
main_graph
.
draw
(
'.'
,
'main'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
marked_nodes
)
marked_nodes
=
set
()
for
op
in
test_graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
marked_nodes
.
add
(
op
)
test_graph
.
draw
(
'.'
,
'test'
+
dev_name
+
quant_type
,
marked_nodes
)
test_graph
.
draw
(
'.'
,
'test'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
marked_nodes
)
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
memory_optimize
=
False
...
...
@@ -282,8 +298,9 @@ class TestQuantizationFreezePass(unittest.TestCase):
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[
loss
])
if
not
for_ci
:
print
(
'{}: {}'
.
format
(
'loss'
+
dev_name
+
quant_type
,
loss_v
))
print
(
'{}: {}'
.
format
(
'loss'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
loss_v
))
test_data
=
next
(
test_reader
())
with
fluid
.
program_guard
(
quantized_test_program
):
...
...
@@ -296,14 +313,17 @@ class TestQuantizationFreezePass(unittest.TestCase):
fetch_list
=
[
loss
,
w_var
])
# Freeze graph for inference, but the weight of fc/conv is still float type.
freeze_pass
=
QuantizationFreezePass
(
scope
=
scope
,
place
=
place
)
freeze_pass
=
QuantizationFreezePass
(
scope
=
scope
,
place
=
place
,
weight_quantize_type
=
weight_quant_type
)
#freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass
.
apply
(
test_graph
)
if
not
for_ci
:
marked_nodes
=
set
()
for
op
in
test_graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
marked_nodes
.
add
(
op
)
test_graph
.
draw
(
'.'
,
'test_freeze'
+
dev_name
+
quant_type
,
test_graph
.
draw
(
'.'
,
'test_freeze'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
marked_nodes
)
server_program
=
test_graph
.
to_program
()
...
...
@@ -313,18 +333,20 @@ class TestQuantizationFreezePass(unittest.TestCase):
fetch_list
=
[
loss
])
self
.
assertAlmostEqual
(
test_loss1
,
test_loss2
,
delta
=
5e-3
)
if
not
for_ci
:
print
(
'{}: {}'
.
format
(
'test_loss1'
+
dev_name
+
quant_type
,
test_loss1
))
print
(
'{}: {}'
.
format
(
'test_loss2'
+
dev_name
+
quant_type
,
test_loss2
))
print
(
'{}: {}'
.
format
(
'test_loss1'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
test_loss1
))
print
(
'{}: {}'
.
format
(
'test_loss2'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
test_loss2
))
w_freeze
=
np
.
array
(
scope
.
find_var
(
'conv2d_1.w_0'
).
get_tensor
())
# Maybe failed, this is due to the calculation precision
# self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
if
not
for_ci
:
print
(
'{}: {}'
.
format
(
'w_freeze'
+
dev_name
+
quant_type
,
np
.
sum
(
w_freeze
)))
print
(
'{}: {}'
.
format
(
'w_quant'
+
dev_name
+
quant_type
,
np
.
sum
(
w_quant
)))
print
(
'{}: {}'
.
format
(
'w_freeze'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
np
.
sum
(
w_freeze
)))
print
(
'{}: {}'
.
format
(
'w_quant'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
np
.
sum
(
w_quant
)))
# Convert parameter to 8-bit.
convert_int8_pass
=
ConvertToInt8Pass
(
scope
=
scope
,
place
=
place
)
...
...
@@ -334,26 +356,28 @@ class TestQuantizationFreezePass(unittest.TestCase):
for
op
in
test_graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
marked_nodes
.
add
(
op
)
test_graph
.
draw
(
'.'
,
'test_int8'
+
dev_name
+
quant_type
,
marked_nodes
)
test_graph
.
draw
(
'.'
,
'test_int8'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
marked_nodes
)
server_program_int8
=
test_graph
.
to_program
()
# Save the 8-bit parameter and model file.
with
fluid
.
scope_guard
(
scope
):
fluid
.
io
.
save_inference_model
(
'server_int8'
+
dev_name
+
quant_type
,
[
'image'
,
'label'
],
[
loss
],
exe
,
server_program_int8
)
fluid
.
io
.
save_inference_model
(
'server_int8'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
[
'image'
,
'label'
],
[
loss
],
exe
,
server_program_int8
)
# Test whether the 8-bit parameter and model file can be loaded successfully.
[
infer
,
feed
,
fetch
]
=
fluid
.
io
.
load_inference_model
(
'server_int8'
+
dev_name
+
quant_type
,
exe
)
'server_int8'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
exe
)
# Check the loaded 8-bit weight.
w_8bit
=
np
.
array
(
scope
.
find_var
(
'conv2d_1.w_0.int8'
).
get_tensor
())
self
.
assertEqual
(
w_8bit
.
dtype
,
np
.
int8
)
self
.
assertEqual
(
np
.
sum
(
w_8bit
),
np
.
sum
(
w_freeze
))
if
not
for_ci
:
print
(
'{}: {}'
.
format
(
'w_8bit'
+
dev_name
+
quant_type
,
np
.
sum
(
w_8bit
)))
print
(
'{}: {}'
.
format
(
'w_freeze'
+
dev_name
+
quant_type
,
np
.
sum
(
w_freeze
)))
print
(
'{}: {}'
.
format
(
'w_8bit'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
np
.
sum
(
w_8bit
)))
print
(
'{}: {}'
.
format
(
'w_freeze'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
np
.
sum
(
w_freeze
)))
mobile_pass
=
TransformForMobilePass
()
mobile_pass
.
apply
(
test_graph
)
...
...
@@ -362,42 +386,103 @@ class TestQuantizationFreezePass(unittest.TestCase):
for
op
in
test_graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
marked_nodes
.
add
(
op
)
test_graph
.
draw
(
'.'
,
'test_mobile'
+
dev_name
+
quant_type
,
test_graph
.
draw
(
'.'
,
'test_mobile'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
marked_nodes
)
mobile_program
=
test_graph
.
to_program
()
with
fluid
.
scope_guard
(
scope
):
fluid
.
io
.
save_inference_model
(
'mobile_int8'
+
dev_name
+
quant_type
,
[
'image'
,
'label'
],
[
loss
],
exe
,
mobile_program
)
fluid
.
io
.
save_inference_model
(
'mobile_int8'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
[
'image'
,
'label'
],
[
loss
],
exe
,
mobile_program
)
def
test_freeze_graph_cuda_dynamic
(
self
):
if
fluid
.
core
.
is_compiled_with_cuda
():
with
fluid
.
unique_name
.
guard
():
self
.
freeze_graph
(
True
,
seed
=
1
,
quant_type
=
'abs_max'
,
for_ci
=
True
)
True
,
seed
=
1
,
activation_quant_type
=
'abs_max'
,
weight_quant_type
=
'abs_max'
,
for_ci
=
True
)
with
fluid
.
unique_name
.
guard
():
self
.
freeze_graph
(
True
,
seed
=
1
,
activation_quant_type
=
'abs_max'
,
weight_quant_type
=
'channel_wise_abs_max'
,
for_ci
=
True
)
def
test_freeze_graph_cpu_dynamic
(
self
):
with
fluid
.
unique_name
.
guard
():
self
.
freeze_graph
(
False
,
seed
=
2
,
quant_type
=
'abs_max'
,
for_ci
=
True
)
self
.
freeze_graph
(
False
,
seed
=
2
,
activation_quant_type
=
'abs_max'
,
weight_quant_type
=
'abs_max'
,
for_ci
=
True
)
self
.
freeze_graph
(
False
,
seed
=
2
,
activation_quant_type
=
'abs_max'
,
weight_quant_type
=
'channel_wise_abs_max'
,
for_ci
=
True
)
def
test_freeze_graph_cuda_static
(
self
):
if
fluid
.
core
.
is_compiled_with_cuda
():
with
fluid
.
unique_name
.
guard
():
self
.
freeze_graph
(
True
,
seed
=
1
,
quant_type
=
'range_abs_max'
,
for_ci
=
True
)
True
,
seed
=
1
,
activation_quant_type
=
'range_abs_max'
,
weight_quant_type
=
'abs_max'
,
for_ci
=
True
)
self
.
freeze_graph
(
True
,
seed
=
1
,
activation_quant_type
=
'moving_average_abs_max'
,
weight_quant_type
=
'abs_max'
,
for_ci
=
True
)
self
.
freeze_graph
(
True
,
seed
=
1
,
quant_type
=
'moving_average_abs_max'
,
activation_quant_type
=
'range_abs_max'
,
weight_quant_type
=
'channel_wise_abs_max'
,
for_ci
=
True
)
self
.
freeze_graph
(
True
,
seed
=
1
,
activation_quant_type
=
'moving_average_abs_max'
,
weight_quant_type
=
'channel_wise_abs_max'
,
for_ci
=
True
)
def
test_freeze_graph_cpu_static
(
self
):
with
fluid
.
unique_name
.
guard
():
self
.
freeze_graph
(
False
,
seed
=
2
,
quant_type
=
'range_abs_max'
,
for_ci
=
True
)
False
,
seed
=
2
,
activation_quant_type
=
'range_abs_max'
,
weight_quant_type
=
'abs_max'
,
for_ci
=
True
)
self
.
freeze_graph
(
False
,
seed
=
2
,
activation_quant_type
=
'moving_average_abs_max'
,
weight_quant_type
=
'abs_max'
,
for_ci
=
True
)
self
.
freeze_graph
(
False
,
seed
=
2
,
activation_quant_type
=
'range_abs_max'
,
weight_quant_type
=
'channel_wise_abs_max'
,
for_ci
=
True
)
self
.
freeze_graph
(
False
,
seed
=
2
,
quant_type
=
'moving_average_abs_max'
,
for_ci
=
True
)
False
,
seed
=
2
,
activation_quant_type
=
'moving_average_abs_max'
,
weight_quant_type
=
'channel_wise_abs_max'
,
for_ci
=
True
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
浏览文件 @
ec11135d
...
...
@@ -31,15 +31,27 @@ def dequantize_max_abs(x, scale, max_range):
return
y
def
channel_wise_quantize_max_abs
(
x
,
quant_bit
=
8
):
def
channel_wise_quantize_max_abs
(
x
,
quant_bit
=
8
,
use_second_dim
=
False
):
scales
=
[]
for
i
in
range
(
x
.
shape
[
0
]):
scales
.
append
(
np
.
max
(
np
.
abs
(
x
[
i
])).
astype
(
"float32"
))
y
=
x
.
copy
()
max_range
=
math
.
pow
(
2
,
quant_bit
-
1
)
-
1
for
i
,
scale
in
enumerate
(
scales
):
y
[
i
]
=
np
.
round
(
y
[
i
]
/
scale
*
max_range
)
if
not
use_second_dim
:
for
i
in
range
(
x
.
shape
[
0
]):
scales
.
append
(
np
.
max
(
np
.
abs
(
x
[
i
])).
astype
(
"float32"
))
y
=
x
.
copy
()
max_range
=
math
.
pow
(
2
,
quant_bit
-
1
)
-
1
for
i
,
scale
in
enumerate
(
scales
):
y
[
i
]
=
np
.
round
(
x
[
i
]
/
scale
*
max_range
)
else
:
for
i
in
range
(
x
.
shape
[
0
]):
s
=
[]
for
j
in
range
(
x
.
shape
[
1
]):
s
.
append
(
np
.
max
(
np
.
abs
(
x
[
i
][
j
])).
astype
(
"float32"
))
scales
.
append
(
s
)
scales
=
np
.
amax
(
np
.
array
(
scales
),
axis
=
0
)
y
=
x
.
copy
()
max_range
=
math
.
pow
(
2
,
quant_bit
-
1
)
-
1
for
i
in
range
(
x
.
shape
[
0
]):
for
j
,
scale
in
enumerate
(
scales
):
y
[
i
][
j
]
=
np
.
round
(
x
[
i
][
j
]
/
scale
*
max_range
)
return
y
,
scales
...
...
@@ -47,10 +59,16 @@ def channel_wise_dequantize_max_abs(x,
scales
,
quant_bits
,
activation_scale
=
None
):
y
=
x
.
copy
()
for
i
in
range
(
x
.
shape
[
0
]):
y
[
i
]
=
(
scales
[
i
]
/
(
math
.
pow
(
2
,
quant_bits
[
0
]
-
1
)
-
1
))
*
y
[
i
]
if
activation_scale
is
not
None
:
if
activation_scale
is
None
:
y
=
x
.
copy
()
for
i
in
range
(
x
.
shape
[
0
]):
y
[
i
]
=
(
scales
[
i
]
/
(
math
.
pow
(
2
,
quant_bits
[
0
]
-
1
)
-
1
))
*
x
[
i
]
else
:
y
=
x
.
copy
()
for
i
in
range
(
x
.
shape
[
0
]):
for
j
in
range
(
x
.
shape
[
1
]):
y
[
i
][
j
]
=
(
scales
[
j
]
/
(
math
.
pow
(
2
,
quant_bits
[
0
]
-
1
)
-
1
))
*
x
[
i
][
j
]
y
*=
activation_scale
/
(
math
.
pow
(
2
,
quant_bits
[
1
]
-
1
)
-
1
)
return
y
...
...
@@ -65,7 +83,8 @@ class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest):
self
.
set_args
()
self
.
op_type
=
"fake_channel_wise_dequantize_max_abs"
x
=
np
.
random
.
randn
(
4
,
3
,
64
,
64
).
astype
(
self
.
data_type
)
yq
,
scales
=
channel_wise_quantize_max_abs
(
x
,
self
.
quant_bits
[
0
])
yq
,
scales
=
channel_wise_quantize_max_abs
(
x
,
self
.
quant_bits
[
0
],
use_second_dim
=
True
)
ydq
=
channel_wise_dequantize_max_abs
(
yq
,
scales
,
self
.
quant_bits
,
self
.
activation_scale
)
...
...
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
浏览文件 @
ec11135d
...
...
@@ -53,7 +53,7 @@ class TestFakeChannelWiseQuantizeOp(OpTest):
self
.
outputs
=
{
'Out'
:
outputs
,
'OutScale
s
'
:
np
.
array
(
scales
).
astype
(
"float32"
),
'OutScale'
:
np
.
array
(
scales
).
astype
(
"float32"
),
}
def
test_check_output
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录