Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ea72246f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ea72246f
编写于
5月 05, 2019
作者:
Z
Zhen Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add MovingAverageAbsMaxScale operator which is only used for calculating the quantization scale.
上级
a72907bb
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
148 addition
and
2 deletion
+148
-2
paddle/fluid/operators/fake_quantize_op.cc
paddle/fluid/operators/fake_quantize_op.cc
+70
-2
paddle/fluid/operators/fake_quantize_op.cu
paddle/fluid/operators/fake_quantize_op.cu
+2
-0
paddle/fluid/operators/fake_quantize_op.h
paddle/fluid/operators/fake_quantize_op.h
+42
-0
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
+34
-0
未找到文件。
paddle/fluid/operators/fake_quantize_op.cc
浏览文件 @
ea72246f
...
...
@@ -388,14 +388,76 @@ class FakeQuantizeMovingAverageAbsMaxOpMaker
AddComment
(
R"DOC(
FakeQuantize operator is used in static quantization.
$$scale = (
0.9*max(abs(x))+accum)/(0.9
*state+1)$$
$$range = 2^{bit_length - 1} - 1$$
$$scale = (
moving\_rate*accum+max(abs(x)))/(moving\_rate
*state+1)$$
$$range = 2^{bit
\
_length - 1} - 1$$
$$Out = round(X/scale * range)$$
)DOC"
);
}
};
class
MovingAverageAbsMaxScaleOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MovingAverageAbsMaxScaleOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of MovingAverageAbsMaxScaleOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"OutScale"
),
"Output(OutScale) of MovingAverageAbsMaxScaleOp"
"should not be null"
);
if
(
ctx
->
HasOutput
(
"OutState"
))
{
ctx
->
SetOutputDim
(
"OutState"
,
{
1
});
}
if
(
ctx
->
HasOutput
(
"OutAccum"
))
{
ctx
->
SetOutputDim
(
"OutAccum"
,
{
1
});
}
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"OutScale"
,
{
1
});
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
ctx
.
GetPlace
());
}
};
class
MovingAverageAbsMaxScaleOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) Input is float data type."
);
AddInput
(
"InAccum"
,
"Last accum."
).
AsDispensable
();
AddInput
(
"InState"
,
"Last state."
).
AsDispensable
();
AddOutput
(
"Out"
,
"(Tensor) Output tensor is just equivalent to the input tensor."
);
AddOutput
(
"OutScale"
,
" Current scale"
);
AddOutput
(
"OutState"
,
"(Tensor) state buffer."
).
AsDispensable
();
AddOutput
(
"OutAccum"
,
"(Tensor) accum buffer."
).
AsDispensable
();
AddAttr
<
float
>
(
"moving_rate"
,
"(float, default 0.9) moving rate."
)
.
SetDefault
(
0.9
);
AddAttr
<
bool
>
(
"is_test"
,
"(bool, default false) Set true for inference only and false "
"for training. Some layers may run faster when this is true."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
MovingAverageAbsMaxScale operator is only used for calculating the quantization scale.
It will not quantize the input tensor.
$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
$$Out = X$$
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
...
...
@@ -426,3 +488,9 @@ REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
fake_channel_wise_quantize_abs_max
,
ops
::
FakeChannelWiseQuantizeAbsMaxKernel
<
CPU
,
float
>
);
REGISTER_OPERATOR
(
moving_average_abs_max_scale
,
ops
::
MovingAverageAbsMaxScaleOp
,
ops
::
MovingAverageAbsMaxScaleOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
moving_average_abs_max_scale
,
ops
::
MovingAverageAbsMaxScaleKernel
<
CPU
,
float
>
);
paddle/fluid/operators/fake_quantize_op.cu
浏览文件 @
ea72246f
...
...
@@ -300,3 +300,5 @@ REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
REGISTER_OP_CUDA_KERNEL
(
fake_quantize_moving_average_abs_max
,
ops
::
FakeQuantizeMovingAverageAbsMaxKernel
<
CUDA
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
moving_average_abs_max_scale
,
ops
::
MovingAverageAbsMaxScaleKernel
<
CUDA
,
float
>
);
paddle/fluid/operators/fake_quantize_op.h
浏览文件 @
ea72246f
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
namespace
paddle
{
...
...
@@ -197,5 +198,46 @@ class FakeQuantizeMovingAverageAbsMaxKernel : public framework::OpKernel<T> {
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
MovingAverageAbsMaxScaleKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
framework
::
TensorCopy
(
*
in
,
context
.
GetPlace
(),
dev_ctx
,
out
);
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
// testing
if
(
is_test
)
{
return
;
}
// training
auto
*
in_accum
=
context
.
Input
<
framework
::
Tensor
>
(
"InAccum"
);
auto
*
in_state
=
context
.
Input
<
framework
::
Tensor
>
(
"InState"
);
auto
&
allocator
=
platform
::
DeviceTemporaryAllocator
::
Instance
().
Get
(
dev_ctx
);
auto
cur_scale
=
allocator
.
Allocate
(
1
*
sizeof
(
T
));
T
*
cur_scale_data
=
static_cast
<
T
*>
(
cur_scale
->
ptr
());
FindAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
->
data
<
T
>
(),
in
->
numel
(),
cur_scale_data
);
auto
*
out_state
=
context
.
Output
<
framework
::
Tensor
>
(
"OutState"
);
auto
*
out_accum
=
context
.
Output
<
framework
::
Tensor
>
(
"OutAccum"
);
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
out_state
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out_accum
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
moving_rate
=
context
.
Attr
<
float
>
(
"moving_rate"
);
FindMovingAverageAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
in_accum
,
*
in_state
,
cur_scale_data
,
moving_rate
,
out_state
,
out_accum
,
out_scale
);
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
浏览文件 @
ea72246f
...
...
@@ -130,6 +130,40 @@ class TestFakeQuantizeMovingOp(OpTest):
self
.
check_output
()
class
TestMovingAverageAbsMaxScaleOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"moving_average_abs_max_scale"
self
.
attrs
=
{
'moving_rate'
:
float
(
0.9
),
'is_test'
:
False
}
accum
=
np
.
zeros
(
1
).
astype
(
"float32"
)
accum
[
0
]
=
1
state
=
np
.
zeros
(
1
).
astype
(
"float32"
)
state
[
0
]
=
1
scale
=
np
.
zeros
(
1
).
astype
(
"float32"
)
scale
[
0
]
=
0.001
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
8
,
16
,
7
,
7
)).
astype
(
"float32"
),
'InAccum'
:
accum
,
'InState'
:
state
,
}
out_accum
=
np
.
zeros
(
1
).
astype
(
"float32"
)
out_state
=
np
.
zeros
(
1
).
astype
(
"float32"
)
out_scale
=
np
.
zeros
(
1
).
astype
(
"float32"
)
out_accum
[
0
]
=
self
.
attrs
[
'moving_rate'
]
*
accum
[
0
]
+
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
])).
astype
(
"float32"
)
out_state
[
0
]
=
self
.
attrs
[
'moving_rate'
]
*
state
[
0
]
+
1
out_scale
=
out_accum
/
out_state
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
],
'OutAccum'
:
out_accum
,
'OutState'
:
out_state
,
'OutScale'
:
out_scale
,
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestFakeQuantizeRangeAbsMaxOp2
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"fake_quantize_range_abs_max"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录