Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
0353eddb
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0353eddb
编写于
8月 28, 2018
作者:
Q
qingqing01
提交者:
GitHub
8月 28, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Improve fake_dequantize_op. (#12877)
* Improve fake_dequantize_op. * Follow comments.
上级
11e01d9b
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
97 addition
and
32 deletion
+97
-32
paddle/fluid/operators/fake_dequantize_op.cc
paddle/fluid/operators/fake_dequantize_op.cc
+25
-12
paddle/fluid/operators/fake_dequantize_op.cu
paddle/fluid/operators/fake_dequantize_op.cu
+36
-0
paddle/fluid/operators/fake_dequantize_op.h
paddle/fluid/operators/fake_dequantize_op.h
+15
-8
python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
...n/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
+21
-12
未找到文件。
paddle/fluid/operators/fake_dequantize_op.cc
浏览文件 @
0353eddb
...
@@ -18,15 +18,32 @@ limitations under the License. */
...
@@ -18,15 +18,32 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
struct
DequantizeFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
dev_ctx
,
const
framework
::
Tensor
*
in
,
const
framework
::
Tensor
*
scale
,
T
max_range
,
framework
::
Tensor
*
out
)
{
auto
in_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
in
);
const
T
*
scale_factor
=
scale
->
data
<
T
>
();
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out
);
auto
&
dev
=
*
dev_ctx
.
eigen_device
();
out_e
.
device
(
dev
)
=
(
scale_factor
[
0
]
/
max_range
)
*
in_e
;
}
};
template
struct
DequantizeFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
DequantizeFunctor
<
platform
::
CPUDeviceContext
,
double
>;
class
FakeDequantizeMaxAbsOp
:
public
framework
::
OperatorWithKernel
{
class
FakeDequantizeMaxAbsOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
FakeDequantizeMaxAbsOp
(
const
std
::
string
&
type
,
FakeDequantizeMaxAbsOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FakeDequantizeMaxAbsOp should not be null."
);
"Input(X) of FakeDequantizeMaxAbsOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
@@ -42,21 +59,17 @@ class FakeDequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -42,21 +59,17 @@ class FakeDequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"X"
,
AddInput
(
"X"
,
"(Tensor) The input with float-32/64 type is the "
"(Tensor) The input with float-32/64 type is the "
"low precision tensor."
);
"low precision tensor."
);
AddInput
(
"Scale"
,
"(float) The scale in quantization stage."
);
AddOutput
(
"Out"
,
AddOutput
(
"Out"
,
"(Tensor) The output is the dequantized high "
"(Tensor) The output is the dequantized high "
"precision tensor."
);
"precision tensor."
);
AddAttr
<
int
>
(
"num_bits"
,
AddAttr
<
float
>
(
"max_range"
,
"(float) The max range in quantization stage."
);
"(int) `num_bits` is the quantization level bits, "
"such as 2, 5, 8."
);
AddAttr
<
float
>
(
"scale"
,
"(float) The maximum absolute value of low precision tensor."
"It is usually calculated by the fake_quantize_max_abs_op."
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
FakeDequantizeMaxAbsOp operator.
FakeDequantizeMaxAbsOp operator.
This calculation is an opposite operation of FakeQuantizeMaxAbsOp:
This calculation is an opposite operation of FakeQuantizeMaxAbsOp:
$$Out = \frac{scale*X}{
2^{num_bits} - 1
}$$
$$Out = \frac{scale*X}{
max_range
}$$
)DOC"
);
)DOC"
);
}
}
...
...
paddle/fluid/operators/fake_dequantize_op.cu
浏览文件 @
0353eddb
...
@@ -14,6 +14,42 @@ limitations under the License. */
...
@@ -14,6 +14,42 @@ limitations under the License. */
#include "paddle/fluid/operators/fake_dequantize_op.h"
#include "paddle/fluid/operators/fake_dequantize_op.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
__global__
void
KeDequantize
(
const
T
*
in
,
const
T
*
scale
,
T
max_range
,
int
num
,
T
*
out
)
{
const
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
num
)
{
out
[
idx
]
=
in
[
idx
]
*
scale
[
0
]
/
max_range
;
}
}
template
<
typename
T
>
struct
DequantizeFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
const
framework
::
Tensor
*
in
,
const
framework
::
Tensor
*
scale
,
T
max_range
,
framework
::
Tensor
*
out
)
{
const
T
*
in_data
=
in
->
data
<
T
>
();
const
T
*
scale_factor
=
scale
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
int
num
=
in
->
numel
();
int
block
=
512
;
int
grid
=
(
num
+
block
-
1
)
/
block
;
KeDequantize
<
T
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
in_data
,
scale_factor
,
max_range
,
num
,
out_data
);
}
};
template
struct
DequantizeFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
DequantizeFunctor
<
platform
::
CUDADeviceContext
,
double
>;
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
REGISTER_OP_CUDA_KERNEL
(
fake_dequantize_max_abs
,
REGISTER_OP_CUDA_KERNEL
(
fake_dequantize_max_abs
,
...
...
paddle/fluid/operators/fake_dequantize_op.h
浏览文件 @
0353eddb
...
@@ -19,22 +19,29 @@ limitations under the License. */
...
@@ -19,22 +19,29 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
struct
DequantizeFunctor
{
void
operator
()(
const
DeviceContext
&
dev_ctx
,
const
framework
::
Tensor
*
in
,
const
framework
::
Tensor
*
scale
,
T
max_range
,
framework
::
Tensor
*
out
);
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
FakeDequantizeMaxAbsKernel
:
public
framework
::
OpKernel
<
T
>
{
class
FakeDequantizeMaxAbsKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
virtual
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
virtual
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
scale
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Scale"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
in
->
place
());
int
num_bits
=
ctx
.
Attr
<
int
>
(
"num_bits"
);
float
max_range
=
ctx
.
Attr
<
float
>
(
"max_range"
);
T
scale
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"scale"
));
int
range
=
std
::
pow
(
2
,
num_bits
)
-
1
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
auto
eigen_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out
);
DequantizeFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
,
scale
,
auto
eigen_in
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
in
);
static_cast
<
T
>
(
max_range
),
out
);
auto
&
dev
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
eigen_out
.
device
(
dev
)
=
(
scale
/
range
)
*
eigen_in
;
}
}
};
};
...
...
python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
浏览文件 @
0353eddb
...
@@ -20,41 +20,50 @@ import math
...
@@ -20,41 +20,50 @@ import math
from
op_test
import
OpTest
from
op_test
import
OpTest
def
quantize_max_abs
(
x
,
num_bits
):
def
quantize_max_abs
(
x
,
max_range
):
range
=
math
.
pow
(
2
,
num_bits
)
-
1
scale
=
np
.
max
(
np
.
abs
(
x
).
flatten
())
scale
=
np
.
max
(
np
.
abs
(
x
).
flatten
())
y
=
np
.
round
(
x
/
scale
*
range
)
y
=
np
.
round
(
x
/
scale
*
max_
range
)
return
y
,
scale
return
y
,
scale
def
dequantize_max_abs
(
x
,
num_bits
,
scale
):
def
dequantize_max_abs
(
x
,
scale
,
max_range
):
range
=
math
.
pow
(
2
,
num_bits
)
-
1
y
=
(
scale
/
max_range
)
*
x
y
=
(
scale
/
range
)
*
x
return
y
return
y
class
TestFakeDequantizeMaxAbsOp
(
OpTest
):
class
TestFakeDequantizeMaxAbsOp
(
OpTest
):
def
set_args
(
self
):
def
set_args
(
self
):
self
.
num_bits
=
8
self
.
num_bits
=
8
self
.
max_range
=
math
.
pow
(
2
,
self
.
num_bits
-
1
)
-
1
self
.
data_type
=
"float32"
def
setUp
(
self
):
def
setUp
(
self
):
self
.
set_args
()
self
.
set_args
()
self
.
op_type
=
"fake_dequantize_max_abs"
self
.
op_type
=
"fake_dequantize_max_abs"
x
=
np
.
random
.
randn
(
31
,
65
).
astype
(
"float32"
)
x
=
np
.
random
.
randn
(
31
,
65
).
astype
(
self
.
data_type
)
yq
,
scale
=
quantize_max_abs
(
x
,
self
.
num_bits
)
yq
,
scale
=
quantize_max_abs
(
x
,
self
.
max_range
)
ydq
=
dequantize_max_abs
(
yq
,
s
elf
.
num_bits
,
scal
e
)
ydq
=
dequantize_max_abs
(
yq
,
s
cale
,
self
.
max_rang
e
)
self
.
inputs
=
{
'X'
:
yq
}
self
.
inputs
=
{
'X'
:
yq
,
'Scale'
:
np
.
array
(
scale
).
astype
(
self
.
data_type
)
}
self
.
attrs
=
{
'
num_bits'
:
self
.
num_bits
,
'scale'
:
float
(
scale
)
}
self
.
attrs
=
{
'
max_range'
:
self
.
max_range
}
self
.
outputs
=
{
'Out'
:
ydq
}
self
.
outputs
=
{
'Out'
:
ydq
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
class
TestFakeDequantizeMaxAbsOp5Bits
(
OpTest
):
class
TestFakeDequantizeMaxAbsOpDouble
(
TestFakeDequantizeMaxAbsOp
):
def
set_args
(
self
):
self
.
num_bits
=
8
self
.
max_range
=
math
.
pow
(
2
,
self
.
num_bits
-
1
)
-
1
self
.
data_type
=
"float64"
class
TestFakeDequantizeMaxAbsOp5Bits
(
TestFakeDequantizeMaxAbsOp
):
def
set_args
(
self
):
def
set_args
(
self
):
self
.
num_bits
=
5
self
.
num_bits
=
5
self
.
max_range
=
math
.
pow
(
2
,
self
.
num_bits
-
1
)
-
1
self
.
data_type
=
"float32"
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录