Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8965819f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8965819f
编写于
3月 21, 2019
作者:
Z
Zhen Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rewrite the cuda kernels of channel_wise_quant_op and channe_wise_dequant_op. test=develop
上级
ec88b6cc
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
405 addition
and
116 deletion
+405
-116
paddle/fluid/operators/fake_dequantize_op.cc
paddle/fluid/operators/fake_dequantize_op.cc
+43
-0
paddle/fluid/operators/fake_dequantize_op.cu
paddle/fluid/operators/fake_dequantize_op.cu
+58
-0
paddle/fluid/operators/fake_dequantize_op.h
paddle/fluid/operators/fake_dequantize_op.h
+16
-27
paddle/fluid/operators/fake_quantize_op.cc
paddle/fluid/operators/fake_quantize_op.cc
+45
-0
paddle/fluid/operators/fake_quantize_op.cu
paddle/fluid/operators/fake_quantize_op.cu
+103
-22
paddle/fluid/operators/fake_quantize_op.h
paddle/fluid/operators/fake_quantize_op.h
+17
-15
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
...ddle/fluid/contrib/slim/quantization/quantization_pass.py
+0
-2
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
...paddle/fluid/contrib/slim/tests/test_quantization_pass.py
+123
-50
未找到文件。
paddle/fluid/operators/fake_dequantize_op.cc
浏览文件 @
8965819f
...
...
@@ -33,8 +33,51 @@ struct DequantizeFunctor<platform::CPUDeviceContext, T> {
}
};
template
<
typename
T
>
struct
ChannelDequantizeFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
dev_ctx
,
const
framework
::
Tensor
*
in
,
const
framework
::
Tensor
**
scales
,
const
int
scale_num
,
T
max_range
,
framework
::
Tensor
*
out
)
{
if
(
scale_num
==
1
)
{
const
int
channel
=
in
->
dims
()[
0
];
const
T
*
scale_factor
=
scales
[
0
]
->
data
<
T
>
();
for
(
int
i
=
0
;
i
<
channel
;
i
++
)
{
T
s
=
scale_factor
[
i
];
framework
::
Tensor
one_channel_in
=
in
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_out
=
out
->
Slice
(
i
,
i
+
1
);
auto
in_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_in
);
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_out
);
auto
&
dev
=
*
dev_ctx
.
eigen_device
();
out_e
.
device
(
dev
)
=
(
s
/
max_range
)
*
in_e
;
}
}
else
if
(
scale_num
==
2
)
{
int
batch_size
=
in
->
dims
()[
0
];
int
channel
=
in
->
dims
()[
1
];
const
T
*
scale_one
=
scales
[
0
]
->
data
<
T
>
();
const
T
*
scale_two
=
scales
[
1
]
->
data
<
T
>
();
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
framework
::
Tensor
one_batch_in
=
in
->
Slice
(
i
,
i
+
1
).
Resize
(
framework
::
slice_ddim
(
in
->
dims
(),
1
,
in
->
dims
().
size
()));
framework
::
Tensor
one_batch_out
=
out
->
Slice
(
i
,
i
+
1
).
Resize
(
framework
::
slice_ddim
(
out
->
dims
(),
1
,
out
->
dims
().
size
()));
for
(
int
j
=
0
;
j
<
channel
;
j
++
)
{
T
s
=
scale_one
[
j
];
framework
::
Tensor
one_channel_in
=
one_batch_in
.
Slice
(
j
,
j
+
1
);
framework
::
Tensor
one_channel_out
=
one_batch_out
.
Slice
(
j
,
j
+
1
);
auto
in_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_in
);
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_out
);
auto
&
dev
=
*
dev_ctx
.
eigen_device
();
out_e
.
device
(
dev
)
=
(
s
*
scale_two
[
0
]
/
max_range
)
*
in_e
;
}
}
}
}
};
template
struct
DequantizeFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
DequantizeFunctor
<
platform
::
CPUDeviceContext
,
double
>;
template
struct
ChannelDequantizeFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
ChannelDequantizeFunctor
<
platform
::
CPUDeviceContext
,
double
>;
class
FakeDequantizeMaxAbsOp
:
public
framework
::
OperatorWithKernel
{
public:
...
...
paddle/fluid/operators/fake_dequantize_op.cu
浏览文件 @
8965819f
...
...
@@ -44,8 +44,66 @@ struct DequantizeFunctor<platform::CUDADeviceContext, T> {
}
};
template
<
typename
T
>
__global__
void
DequantizeOneScale
(
const
T
*
in
,
const
T
*
scale
,
T
max_range
,
int
num
,
int
channel
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
;
int
channel_size
=
num
/
channel
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
channel_size
;
for
(
int
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
out_c
[
i
]
=
in_c
[
i
]
*
scale
[
blockIdx
.
x
]
/
max_range
;
}
}
template
<
typename
T
>
__global__
void
DequantizeTwoScale
(
const
T
*
in
,
const
T
*
scale_one
,
const
T
*
scale_two
,
T
max_range
,
int
num
,
int
batch_size
,
int
channel
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
;
int
channel_size
=
num
/
(
batch_size
*
channel
);
int
scale_index
=
blockIdx
.
x
%
channel
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
channel_size
;
for
(
int
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
out_c
[
i
]
=
in_c
[
i
]
*
scale_one
[
scale_index
]
*
scale_two
[
0
]
/
max_range
;
}
}
template
<
typename
T
>
struct
ChannelDequantizeFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
const
framework
::
Tensor
*
in
,
const
framework
::
Tensor
**
scales
,
const
int
scale_num
,
T
max_range
,
framework
::
Tensor
*
out
)
{
const
T
*
in_data
=
in
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
if
(
scale_num
==
1
)
{
int
num
=
in
->
numel
();
int
channel
=
in
->
dims
()[
0
];
const
T
*
scale_factor
=
scales
[
0
]
->
data
<
T
>
();
int
block
=
1024
;
int
grid
=
channel
;
DequantizeOneScale
<
T
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
in_data
,
scale_factor
,
max_range
,
num
,
channel
,
out_data
);
}
else
if
(
scale_num
==
2
)
{
int
num
=
in
->
numel
();
int
batch_size
=
in
->
dims
()[
0
];
int
channel
=
in
->
dims
()[
1
];
const
T
*
scale_one
=
scales
[
0
]
->
data
<
T
>
();
const
T
*
scale_two
=
scales
[
1
]
->
data
<
T
>
();
int
block
=
1024
;
int
grid
=
batch_size
*
channel
;
DequantizeTwoScale
<
T
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
in_data
,
scale_one
,
scale_two
,
max_range
,
num
,
batch_size
,
channel
,
out_data
);
}
}
};
template
struct
DequantizeFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
DequantizeFunctor
<
platform
::
CUDADeviceContext
,
double
>;
template
struct
ChannelDequantizeFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
ChannelDequantizeFunctor
<
platform
::
CUDADeviceContext
,
double
>;
}
// namespace operators
}
// namespace paddle
...
...
paddle/fluid/operators/fake_dequantize_op.h
浏览文件 @
8965819f
...
...
@@ -29,6 +29,13 @@ struct DequantizeFunctor {
framework
::
Tensor
*
out
);
};
template
<
typename
DeviceContext
,
typename
T
>
struct
ChannelDequantizeFunctor
{
void
operator
()(
const
DeviceContext
&
dev_ctx
,
const
framework
::
Tensor
*
in
,
const
framework
::
Tensor
**
scales
,
const
int
scale_num
,
T
max_range
,
framework
::
Tensor
*
out
);
};
template
<
typename
DeviceContext
,
typename
T
>
class
FakeDequantizeMaxAbsKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -56,50 +63,32 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
quant_bits
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"quant_bits"
);
int
max_range
=
std
::
pow
(
2
,
quant_bits
[
0
]
-
1
)
-
1
;
int
max_range
=
1
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
auto
dequant
=
DequantizeFunctor
<
DeviceContext
,
T
>
();
if
(
scales
.
size
()
==
1
)
{
int
scale_num
=
scales
.
size
();
if
(
scale_num
==
1
)
{
PADDLE_ENFORCE_EQ
(
scales
[
0
]
->
numel
(),
in
->
dims
()[
0
],
"The number of first scale values must be the same with "
"first dimension value of Input(X) when the `Scales` has only one "
"element."
);
for
(
int64_t
i
=
0
;
i
<
in
->
dims
()[
0
];
i
++
)
{
framework
::
Tensor
one_channel_in
=
in
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_out
=
out
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_scale
=
scales
[
0
]
->
Slice
(
i
,
i
+
1
);
dequant
(
dev_ctx
,
&
one_channel_in
,
&
one_channel_scale
,
static_cast
<
T
>
(
max_range
),
&
one_channel_out
);
}
}
else
if
(
scales
.
size
()
==
2
)
{
max_range
*=
(
std
::
pow
(
2
,
quant_bits
[
0
]
-
1
)
-
1
);
}
else
if
(
scale_num
==
2
)
{
PADDLE_ENFORCE_EQ
(
scales
[
0
]
->
numel
(),
in
->
dims
()[
1
],
"The number of first scale values must be the same with "
"second dimension value of Input(X) when the `Scales` has two "
"elements."
);
for
(
int64_t
i
=
0
;
i
<
in
->
dims
()[
0
];
i
++
)
{
framework
::
Tensor
one_batch_in
=
in
->
Slice
(
i
,
i
+
1
).
Resize
(
framework
::
slice_ddim
(
in
->
dims
(),
1
,
in
->
dims
().
size
()));
framework
::
Tensor
one_batch_out
=
out
->
Slice
(
i
,
i
+
1
).
Resize
(
framework
::
slice_ddim
(
out
->
dims
(),
1
,
out
->
dims
().
size
()));
for
(
int64_t
j
=
0
;
j
<
in
->
dims
()[
1
];
j
++
)
{
framework
::
Tensor
one_channel_in
=
one_batch_in
.
Slice
(
j
,
j
+
1
);
framework
::
Tensor
one_channel_out
=
one_batch_out
.
Slice
(
j
,
j
+
1
);
framework
::
Tensor
one_channel_scale
=
scales
[
0
]
->
Slice
(
j
,
j
+
1
);
dequant
(
dev_ctx
,
&
one_channel_in
,
&
one_channel_scale
,
static_cast
<
T
>
(
max_range
),
&
one_channel_out
);
}
}
PADDLE_ENFORCE_EQ
(
scales
[
1
]
->
numel
(),
1
,
"The second scale tensor should only have one value at now."
);
max_range
=
std
::
pow
(
2
,
quant_bits
[
1
]
-
1
)
-
1
;
dequant
(
dev_ctx
,
out
,
scales
[
1
],
static_cast
<
T
>
(
max_range
),
out
);
max_range
*=
(
std
::
pow
(
2
,
quant_bits
[
0
]
-
1
)
-
1
)
*
(
std
::
pow
(
2
,
quant_bits
[
1
]
-
1
)
-
1
);
}
ChannelDequantizeFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
,
scales
.
data
(),
scale_num
,
static_cast
<
T
>
(
max_range
),
out
);
}
};
...
...
paddle/fluid/operators/fake_quantize_op.cc
浏览文件 @
8965819f
...
...
@@ -37,6 +37,21 @@ struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
template
struct
FindAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
<
typename
T
>
struct
FindChannelAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
T
*
in
,
const
int
num
,
const
int
channel
,
T
*
out
)
{
const
int
channel_size
=
num
/
channel
;
for
(
int
i
=
0
;
i
<
channel
;
i
++
)
{
auto
*
start
=
in
+
i
*
channel_size
;
auto
*
end
=
in
+
(
i
+
1
)
*
channel_size
;
out
[
i
]
=
std
::
abs
(
*
(
std
::
max_element
(
start
,
end
,
Compare
<
T
>
())));
}
}
};
template
struct
FindChannelAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
<
typename
T
>
struct
ClipAndFakeQuantFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
...
...
@@ -53,6 +68,36 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
template
struct
ClipAndFakeQuantFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
<
typename
T
>
struct
ChannelClipAndFakeQuantFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
const
int
channel
,
framework
::
Tensor
*
out
)
{
auto
*
scale_data
=
scale
.
data
<
T
>
();
auto
*
in_data
=
in
.
data
<
T
>
();
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
int
channel_size
=
in
.
numel
()
/
channel
;
platform
::
Transform
<
platform
::
CPUDeviceContext
>
trans
;
for
(
int
i
=
0
;
i
<
channel
;
i
++
)
{
T
s
=
scale_data
[
i
];
auto
*
start
=
in_data
+
i
*
channel_size
;
auto
*
end
=
in_data
+
(
i
+
1
)
*
channel_size
;
trans
(
ctx
,
start
,
end
,
out_data
+
i
*
channel_size
,
ClipFunctor
<
T
>
(
-
s
,
s
));
}
for
(
int
i
=
0
;
i
<
channel
;
i
++
)
{
T
s
=
scale_data
[
i
];
framework
::
Tensor
one_channel_out
=
out
->
Slice
(
i
,
i
+
1
);
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_out
);
out_e
.
device
(
*
ctx
.
eigen_device
())
=
(
bin_cnt
/
s
*
out_e
).
round
();
}
}
};
template
struct
ChannelClipAndFakeQuantFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
<
typename
T
>
struct
FindRangeAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
...
...
paddle/fluid/operators/fake_quantize_op.cu
浏览文件 @
8965819f
...
...
@@ -74,6 +74,45 @@ struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
template
struct
FindAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
<
typename
T
>
__global__
void
FindChannelAbsMaxKernel
(
const
T
*
in
,
const
int
n
,
const
int
c
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
;
int
channel_size
=
n
/
c
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
extern
__shared__
T
shared_max_data
[];
shared_max_data
[
tid
]
=
T
(
0
);
for
(
int
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
T
tmp
=
fabs
(
in_c
[
i
]);
if
(
tmp
>
shared_max_data
[
tid
])
{
shared_max_data
[
tid
]
=
tmp
;
}
}
__syncthreads
();
for
(
int
i
=
blockDim
.
x
/
2
;
i
>
0
;
i
>>=
1
)
{
if
(
tid
<
i
&&
(
shared_max_data
[
tid
]
<
shared_max_data
[
tid
+
i
]))
{
shared_max_data
[
tid
]
=
shared_max_data
[
tid
+
i
];
}
__syncthreads
();
}
if
(
tid
==
0
)
{
out
[
blockIdx
.
x
]
=
shared_max_data
[
0
];
}
}
template
<
typename
T
>
struct
FindChannelAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
T
*
in
,
const
int
num
,
const
int
channel
,
T
*
out
)
{
int
block
=
1024
;
int
grid
=
channel
;
FindChannelAbsMaxKernel
<
T
><<<
grid
,
block
,
1024
*
sizeof
(
T
),
ctx
.
stream
()
>>>
(
in
,
num
,
channel
,
out
);
}
};
template
struct
FindChannelAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
<
typename
T
>
__global__
void
ClipAndQuantKernel
(
const
T
*
in
,
const
T
*
scale
,
const
int
bin_cnt
,
const
int
n
,
T
*
out
)
{
...
...
@@ -82,14 +121,76 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
T
s
=
scale
[
0
];
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
x
=
in
[
bid
];
T
x
=
in
[
i
];
T
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt
/
s
*
v
;
out
[
bid
]
=
round
(
v
);
out
[
i
]
=
round
(
v
);
}
}
template
<
typename
T
>
struct
ClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
framework
::
Tensor
*
out
)
{
int
num
=
in
.
numel
();
int
block
=
1024
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ClipAndQuantKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
in_data
,
scale_data
,
bin_cnt
,
num
,
out_data
);
}
};
template
struct
ClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
<
typename
T
>
__global__
void
ChannelClipAndQuantKernel
(
const
T
*
in
,
const
T
*
scale
,
const
int
bin_cnt
,
const
int
n
,
const
int
c
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
;
int
channel_size
=
n
/
c
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
channel_size
;
T
s
=
scale
[
blockIdx
.
x
];
for
(
int
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
T
x
=
in_c
[
i
];
T
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt
/
s
*
v
;
out_c
[
i
]
=
round
(
v
);
}
}
template
<
typename
T
>
struct
ChannelClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
const
int
channel
,
framework
::
Tensor
*
out
)
{
int
num
=
in
.
numel
();
int
block
=
1024
;
int
grid
=
channel
;
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ChannelClipAndQuantKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
in_data
,
scale_data
,
bin_cnt
,
num
,
channel
,
out_data
);
}
};
template
struct
ChannelClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
<
typename
T
>
__global__
void
FindRangeAbsMaxAndFillArray
(
const
T
*
cur_scale
,
const
T
*
last_scale
,
...
...
@@ -182,26 +283,6 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
template
struct
FindMovingAverageAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
<
typename
T
>
struct
ClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
framework
::
Tensor
*
out
)
{
int
num
=
in
.
numel
();
int
block
=
1024
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ClipAndQuantKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
in_data
,
scale_data
,
bin_cnt
,
num
,
out_data
);
}
};
template
struct
ClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
float
>;
}
// namespace operators
}
// namespace paddle
...
...
paddle/fluid/operators/fake_quantize_op.h
浏览文件 @
8965819f
...
...
@@ -42,6 +42,19 @@ struct FindRangeAbsMaxFunctor {
framework
::
Tensor
*
scales_arr
,
framework
::
Tensor
*
out_scale
);
};
template
<
typename
DeviceContext
,
typename
T
>
struct
FindChannelAbsMaxFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
T
*
in
,
const
int
num
,
const
int
channel
,
T
*
out
);
};
template
<
typename
DeviceContext
,
typename
T
>
struct
ChannelClipAndFakeQuantFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
const
int
channel
,
framework
::
Tensor
*
out
);
};
template
<
typename
DeviceContext
,
typename
T
>
struct
FindMovingAverageAbsMaxFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
in_accum
,
...
...
@@ -86,21 +99,10 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
find_abs_max
=
FindAbsMaxFunctor
<
DeviceContext
,
T
>
();
for
(
int64_t
i
=
0
;
i
<
in
->
dims
()[
0
];
i
++
)
{
framework
::
Tensor
one_channel
=
in
->
Slice
(
i
,
i
+
1
);
const
T
*
one_channel_data
=
one_channel
.
data
<
T
>
();
find_abs_max
(
dev_ctx
,
one_channel_data
,
one_channel
.
numel
(),
&
out_scale_data
[
i
]);
}
auto
clip_quant
=
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
();
for
(
int64_t
i
=
0
;
i
<
in
->
dims
()[
0
];
i
++
)
{
framework
::
Tensor
one_channel_in
=
in
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_out
=
out
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_scale
=
out_scale
->
Slice
(
i
,
i
+
1
);
clip_quant
(
dev_ctx
,
one_channel_in
,
one_channel_scale
,
bin_cnt
,
&
one_channel_out
);
}
FindChannelAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
->
data
<
T
>
(),
in
->
numel
(),
in
->
dims
()[
0
],
out_scale_data
);
ChannelClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
in
,
*
out_scale
,
bin_cnt
,
in
->
dims
()[
0
],
out
);
}
};
...
...
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
浏览文件 @
8965819f
...
...
@@ -576,8 +576,6 @@ class QuantizationFreezePass(object):
elif
self
.
_weight_quantize_type
==
'channel_wise_abs_max'
:
param
=
self
.
_load_var
(
input_arg_name
)
if
len
(
param
.
shape
)
==
4
:
# conv2d or depthwise_conv2d
print
(
'DEBUG**************************: %s'
%
input_arg_name
)
scale_v
=
[]
for
i
in
range
(
param
.
shape
[
0
]):
scale_v
.
append
(
np
.
max
(
np
.
abs
(
param
[
i
])))
...
...
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
浏览文件 @
8965819f
...
...
@@ -127,7 +127,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
arg_name
.
endswith
(
'.quantized.dequantized'
))
self
.
assertTrue
(
arg_name
in
quantized_ops
)
def
linear_fc_quant
(
self
,
quant_type
,
for_ci
=
False
):
def
linear_fc_quant
(
self
,
activation_
quant_type
,
for_ci
=
False
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
,
startup
):
...
...
@@ -140,14 +140,15 @@ class TestQuantizationTransformPass(unittest.TestCase):
transform_pass
=
QuantizationTransformPass
(
scope
=
fluid
.
global_scope
(),
place
=
place
,
activation_quantize_type
=
quant_type
)
activation_quantize_type
=
activation_
quant_type
)
transform_pass
.
apply
(
graph
)
if
not
for_ci
:
marked_nodes
=
set
()
for
op
in
graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
marked_nodes
.
add
(
op
)
graph
.
draw
(
'.'
,
'quantize_fc_'
+
quant_type
,
marked_nodes
)
graph
.
draw
(
'.'
,
'quantize_fc_'
+
activation_quant_type
,
marked_nodes
)
program
=
graph
.
to_program
()
self
.
check_program
(
transform_pass
,
program
)
val_graph
=
IrGraph
(
core
.
Graph
(
program
.
desc
),
for_test
=
False
)
...
...
@@ -156,7 +157,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
for
op
in
val_graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
val_marked_nodes
.
add
(
op
)
val_graph
.
draw
(
'.'
,
'val_fc_'
+
quant_type
,
val_marked_nodes
)
val_graph
.
draw
(
'.'
,
'val_fc_'
+
activation_quant_type
,
val_marked_nodes
)
def
test_linear_fc_quant_abs_max
(
self
):
self
.
linear_fc_quant
(
'abs_max'
,
for_ci
=
True
)
...
...
@@ -167,7 +169,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
def
test_linear_fc_quant_moving_average_abs_max
(
self
):
self
.
linear_fc_quant
(
'moving_average_abs_max'
,
for_ci
=
True
)
def
residual_block_quant
(
self
,
quant_type
,
for_ci
=
False
):
def
residual_block_quant
(
self
,
activation_
quant_type
,
for_ci
=
False
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
,
startup
):
...
...
@@ -180,14 +182,15 @@ class TestQuantizationTransformPass(unittest.TestCase):
transform_pass
=
QuantizationTransformPass
(
scope
=
fluid
.
global_scope
(),
place
=
place
,
activation_quantize_type
=
quant_type
)
activation_quantize_type
=
activation_
quant_type
)
transform_pass
.
apply
(
graph
)
if
not
for_ci
:
marked_nodes
=
set
()
for
op
in
graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
marked_nodes
.
add
(
op
)
graph
.
draw
(
'.'
,
'quantize_residual_'
+
quant_type
,
marked_nodes
)
graph
.
draw
(
'.'
,
'quantize_residual_'
+
activation_quant_type
,
marked_nodes
)
program
=
graph
.
to_program
()
self
.
check_program
(
transform_pass
,
program
)
val_graph
=
IrGraph
(
core
.
Graph
(
program
.
desc
),
for_test
=
False
)
...
...
@@ -196,7 +199,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
for
op
in
val_graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
val_marked_nodes
.
add
(
op
)
val_graph
.
draw
(
'.'
,
'val_residual_'
+
quant_type
,
val_marked_nodes
)
val_graph
.
draw
(
'.'
,
'val_residual_'
+
activation_quant_type
,
val_marked_nodes
)
def
test_residual_block_abs_max
(
self
):
self
.
residual_block_quant
(
'abs_max'
,
for_ci
=
True
)
...
...
@@ -209,7 +213,12 @@ class TestQuantizationTransformPass(unittest.TestCase):
class
TestQuantizationFreezePass
(
unittest
.
TestCase
):
def
freeze_graph
(
self
,
use_cuda
,
seed
,
quant_type
,
for_ci
=
False
):
def
freeze_graph
(
self
,
use_cuda
,
seed
,
activation_quant_type
,
weight_quant_type
=
'abs_max'
,
for_ci
=
False
):
def
build_program
(
main
,
startup
,
is_test
):
main
.
random_seed
=
seed
startup
.
random_seed
=
seed
...
...
@@ -245,10 +254,10 @@ class TestQuantizationFreezePass(unittest.TestCase):
transform_pass
=
QuantizationTransformPass
(
scope
=
scope
,
place
=
place
,
activation_quantize_type
=
quant_type
,
weight_quantize_type
=
'channel_wise_abs_max'
)
activation_quantize_type
=
activation_
quant_type
,
weight_quantize_type
=
weight_quant_type
)
#transform_pass = QuantizationTransformPass(
# scope=scope, place=place, activation_quantize_type=quant_type)
# scope=scope, place=place, activation_quantize_type=
activation_
quant_type)
transform_pass
.
apply
(
main_graph
)
transform_pass
.
apply
(
test_graph
)
dev_name
=
'_gpu_'
if
use_cuda
else
'_cpu_'
...
...
@@ -257,12 +266,14 @@ class TestQuantizationFreezePass(unittest.TestCase):
for
op
in
main_graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
marked_nodes
.
add
(
op
)
main_graph
.
draw
(
'.'
,
'main'
+
dev_name
+
quant_type
,
marked_nodes
)
main_graph
.
draw
(
'.'
,
'main'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
marked_nodes
)
marked_nodes
=
set
()
for
op
in
test_graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
marked_nodes
.
add
(
op
)
test_graph
.
draw
(
'.'
,
'test'
+
dev_name
+
quant_type
,
marked_nodes
)
test_graph
.
draw
(
'.'
,
'test'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
marked_nodes
)
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
memory_optimize
=
False
...
...
@@ -287,8 +298,9 @@ class TestQuantizationFreezePass(unittest.TestCase):
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[
loss
])
if
not
for_ci
:
print
(
'{}: {}'
.
format
(
'loss'
+
dev_name
+
quant_type
,
loss_v
))
print
(
'{}: {}'
.
format
(
'loss'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
loss_v
))
test_data
=
next
(
test_reader
())
with
fluid
.
program_guard
(
quantized_test_program
):
...
...
@@ -302,9 +314,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
# Freeze graph for inference, but the weight of fc/conv is still float type.
freeze_pass
=
QuantizationFreezePass
(
scope
=
scope
,
place
=
place
,
weight_quantize_type
=
'channel_wise_abs_max'
)
scope
=
scope
,
place
=
place
,
weight_quantize_type
=
weight_quant_type
)
#freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass
.
apply
(
test_graph
)
if
not
for_ci
:
...
...
@@ -312,7 +322,8 @@ class TestQuantizationFreezePass(unittest.TestCase):
for
op
in
test_graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
marked_nodes
.
add
(
op
)
test_graph
.
draw
(
'.'
,
'test_freeze'
+
dev_name
+
quant_type
,
test_graph
.
draw
(
'.'
,
'test_freeze'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
marked_nodes
)
server_program
=
test_graph
.
to_program
()
...
...
@@ -322,18 +333,20 @@ class TestQuantizationFreezePass(unittest.TestCase):
fetch_list
=
[
loss
])
self
.
assertAlmostEqual
(
test_loss1
,
test_loss2
,
delta
=
5e-3
)
if
not
for_ci
:
print
(
'{}: {}'
.
format
(
'test_loss1'
+
dev_name
+
quant_type
,
test_loss1
))
print
(
'{}: {}'
.
format
(
'test_loss2'
+
dev_name
+
quant_type
,
test_loss2
))
print
(
'{}: {}'
.
format
(
'test_loss1'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
test_loss1
))
print
(
'{}: {}'
.
format
(
'test_loss2'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
test_loss2
))
w_freeze
=
np
.
array
(
scope
.
find_var
(
'conv2d_1.w_0'
).
get_tensor
())
# Maybe failed, this is due to the calculation precision
# self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
if
not
for_ci
:
print
(
'{}: {}'
.
format
(
'w_freeze'
+
dev_name
+
quant_type
,
np
.
sum
(
w_freeze
)))
print
(
'{}: {}'
.
format
(
'w_quant'
+
dev_name
+
quant_type
,
np
.
sum
(
w_quant
)))
print
(
'{}: {}'
.
format
(
'w_freeze'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
np
.
sum
(
w_freeze
)))
print
(
'{}: {}'
.
format
(
'w_quant'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
np
.
sum
(
w_quant
)))
# Convert parameter to 8-bit.
convert_int8_pass
=
ConvertToInt8Pass
(
scope
=
scope
,
place
=
place
)
...
...
@@ -343,26 +356,28 @@ class TestQuantizationFreezePass(unittest.TestCase):
for
op
in
test_graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
marked_nodes
.
add
(
op
)
test_graph
.
draw
(
'.'
,
'test_int8'
+
dev_name
+
quant_type
,
marked_nodes
)
test_graph
.
draw
(
'.'
,
'test_int8'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
marked_nodes
)
server_program_int8
=
test_graph
.
to_program
()
# Save the 8-bit parameter and model file.
with
fluid
.
scope_guard
(
scope
):
fluid
.
io
.
save_inference_model
(
'server_int8'
+
dev_name
+
quant_type
,
[
'image'
,
'label'
],
[
loss
],
exe
,
server_program_int8
)
fluid
.
io
.
save_inference_model
(
'server_int8'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
[
'image'
,
'label'
],
[
loss
],
exe
,
server_program_int8
)
# Test whether the 8-bit parameter and model file can be loaded successfully.
[
infer
,
feed
,
fetch
]
=
fluid
.
io
.
load_inference_model
(
'server_int8'
+
dev_name
+
quant_type
,
exe
)
'server_int8'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
exe
)
# Check the loaded 8-bit weight.
w_8bit
=
np
.
array
(
scope
.
find_var
(
'conv2d_1.w_0.int8'
).
get_tensor
())
self
.
assertEqual
(
w_8bit
.
dtype
,
np
.
int8
)
self
.
assertEqual
(
np
.
sum
(
w_8bit
),
np
.
sum
(
w_freeze
))
if
not
for_ci
:
print
(
'{}: {}'
.
format
(
'w_8bit'
+
dev_name
+
quant_type
,
np
.
sum
(
w_8bit
)))
print
(
'{}: {}'
.
format
(
'w_freeze'
+
dev_name
+
quant_type
,
np
.
sum
(
w_freeze
)))
print
(
'{}: {}'
.
format
(
'w_8bit'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
np
.
sum
(
w_8bit
)))
print
(
'{}: {}'
.
format
(
'w_freeze'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
np
.
sum
(
w_freeze
)))
mobile_pass
=
TransformForMobilePass
()
mobile_pass
.
apply
(
test_graph
)
...
...
@@ -371,45 +386,103 @@ class TestQuantizationFreezePass(unittest.TestCase):
for
op
in
test_graph
.
all_op_nodes
():
if
op
.
name
().
find
(
'quantize'
)
>
-
1
:
marked_nodes
.
add
(
op
)
test_graph
.
draw
(
'.'
,
'test_mobile'
+
dev_name
+
quant_type
,
test_graph
.
draw
(
'.'
,
'test_mobile'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
marked_nodes
)
mobile_program
=
test_graph
.
to_program
()
with
fluid
.
scope_guard
(
scope
):
fluid
.
io
.
save_inference_model
(
'mobile_int8'
+
dev_name
+
quant_type
,
[
'image'
,
'label'
],
[
loss
],
exe
,
mobile_program
)
fluid
.
io
.
save_inference_model
(
'mobile_int8'
+
dev_name
+
activation_quant_type
+
'_'
+
weight_quant_type
,
[
'image'
,
'label'
],
[
loss
],
exe
,
mobile_program
)
def
test_freeze_graph_cuda_dynamic
(
self
):
if
fluid
.
core
.
is_compiled_with_cuda
():
with
fluid
.
unique_name
.
guard
():
self
.
freeze_graph
(
True
,
seed
=
1
,
quant_type
=
'abs_max'
,
for_ci
=
False
)
True
,
seed
=
1
,
activation_quant_type
=
'abs_max'
,
weight_quant_type
=
'abs_max'
,
for_ci
=
True
)
with
fluid
.
unique_name
.
guard
():
self
.
freeze_graph
(
True
,
seed
=
1
,
activation_quant_type
=
'abs_max'
,
weight_quant_type
=
'channel_wise_abs_max'
,
for_ci
=
True
)
def
test_freeze_graph_cpu_dynamic
(
self
):
with
fluid
.
unique_name
.
guard
():
self
.
freeze_graph
(
False
,
seed
=
2
,
quant_type
=
'abs_max'
,
for_ci
=
False
)
self
.
freeze_graph
(
False
,
seed
=
2
,
activation_quant_type
=
'abs_max'
,
weight_quant_type
=
'abs_max'
,
for_ci
=
True
)
self
.
freeze_graph
(
False
,
seed
=
2
,
activation_quant_type
=
'abs_max'
,
weight_quant_type
=
'channel_wise_abs_max'
,
for_ci
=
True
)
def
test_freeze_graph_cuda_static
(
self
):
if
fluid
.
core
.
is_compiled_with_cuda
():
with
fluid
.
unique_name
.
guard
():
self
.
freeze_graph
(
True
,
seed
=
1
,
quant_type
=
'range_abs_max'
,
for_ci
=
False
)
True
,
seed
=
1
,
activation_quant_type
=
'range_abs_max'
,
weight_quant_type
=
'abs_max'
,
for_ci
=
True
)
self
.
freeze_graph
(
True
,
seed
=
1
,
activation_quant_type
=
'moving_average_abs_max'
,
weight_quant_type
=
'abs_max'
,
for_ci
=
True
)
self
.
freeze_graph
(
True
,
seed
=
1
,
quant_type
=
'moving_average_abs_max'
,
for_ci
=
False
)
activation_quant_type
=
'range_abs_max'
,
weight_quant_type
=
'channel_wise_abs_max'
,
for_ci
=
True
)
self
.
freeze_graph
(
True
,
seed
=
1
,
activation_quant_type
=
'moving_average_abs_max'
,
weight_quant_type
=
'channel_wise_abs_max'
,
for_ci
=
True
)
def
test_freeze_graph_cpu_static
(
self
):
with
fluid
.
unique_name
.
guard
():
self
.
freeze_graph
(
False
,
seed
=
2
,
quant_type
=
'range_abs_max'
,
for_ci
=
False
)
False
,
seed
=
2
,
activation_quant_type
=
'range_abs_max'
,
weight_quant_type
=
'abs_max'
,
for_ci
=
True
)
self
.
freeze_graph
(
False
,
seed
=
2
,
activation_quant_type
=
'moving_average_abs_max'
,
weight_quant_type
=
'abs_max'
,
for_ci
=
True
)
self
.
freeze_graph
(
False
,
seed
=
2
,
activation_quant_type
=
'range_abs_max'
,
weight_quant_type
=
'channel_wise_abs_max'
,
for_ci
=
True
)
self
.
freeze_graph
(
False
,
seed
=
2
,
quant_type
=
'moving_average_abs_max'
,
for_ci
=
False
)
activation_quant_type
=
'moving_average_abs_max'
,
weight_quant_type
=
'channel_wise_abs_max'
,
for_ci
=
True
)
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录