Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
8e4b225f
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8e4b225f
编写于
7月 11, 2018
作者:
视言
提交者:
qingqing01
7月 11, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add fake_quantize_op. (#11359)
* Add a fake_quantize_op, which quantize an input tensor to a tensor with lower bits.
上级
79d797fd
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
596 addition
and
6 deletion
+596
-6
paddle/fluid/operators/fake_quantize_op.cc
paddle/fluid/operators/fake_quantize_op.cc
+112
-0
paddle/fluid/operators/fake_quantize_op.cu
paddle/fluid/operators/fake_quantize_op.cu
+272
-0
paddle/fluid/operators/fake_quantize_op.h
paddle/fluid/operators/fake_quantize_op.h
+155
-0
paddle/legacy/capi/examples/model_inference/multi_thread/convert_protobin.sh
...examples/model_inference/multi_thread/convert_protobin.sh
+1
-1
paddle/legacy/capi/examples/model_inference/multi_thread/convert_protobin.sh
...examples/model_inference/multi_thread/convert_protobin.sh
+1
-1
paddle/legacy/capi/examples/model_inference/sequence/convert_protobin.sh
...api/examples/model_inference/sequence/convert_protobin.sh
+1
-1
paddle/legacy/capi/examples/model_inference/sequence/convert_protobin.sh
...api/examples/model_inference/sequence/convert_protobin.sh
+1
-1
paddle/legacy/capi/examples/model_inference/sparse_binary/convert_protobin.sh
...xamples/model_inference/sparse_binary/convert_protobin.sh
+1
-1
paddle/legacy/capi/examples/model_inference/sparse_binary/convert_protobin.sh
...xamples/model_inference/sparse_binary/convert_protobin.sh
+1
-1
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
+51
-0
未找到文件。
paddle/fluid/operators/fake_quantize_op.cc
0 → 100644
浏览文件 @
8e4b225f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fake_quantize_op.h"
#include <string>
namespace
paddle
{
namespace
operators
{
class
FakeQuantizeOp
:
public
framework
::
OperatorWithKernel
{
public:
FakeQuantizeOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FakeQuantizeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FakeQuantizeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"OutMovingScale"
),
"OutMovingScale(Out) of FakeQuantizeOp should not be null"
);
// if (ctx->HasInput("InMovingScale")) {
ctx
->
SetOutputDim
(
"OutMovingScale"
,
ctx
->
GetInputDim
(
"InMovingScale"
));
//}
// if (ctx->HasInput("InScales")) {
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"OutScales"
),
"OutScales(Out) of FakeQuantizeOp should not be null"
);
ctx
->
SetOutputDim
(
"OutScales"
,
ctx
->
GetInputDim
(
"InScales"
));
// PADDLE_ENFORCE_EQ(ctx->Inputs("InScales")[0],
// ctx->Outputs("OutScales")[0],
// "Mean and MeanOut should share the same memory");
//}
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
};
class
FakeQuantizeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) Input tensor of scale operator."
);
AddInput
(
"InScales"
,
"(Tensor) scale buffer, used in static quantization."
)
.
AsDispensable
();
AddInput
(
"InMovingScale"
,
"Last scale, used in static quantization."
)
.
AsDispensable
();
AddInput
(
"InCurrentIter"
,
"Last iteration number, used in static quantization."
)
.
AsDispensable
();
AddOutput
(
"Out"
,
"(Tensor) Output of quantized low level tensor."
);
AddOutput
(
"OutScales"
,
"(Tensor) scale buffer, used in static quantization."
)
.
AsDispensable
();
AddOutput
(
"OutMovingScale"
,
" Current scale"
);
AddOutput
(
"OutCurrentIter"
,
"Current iteration number."
).
AsDispensable
();
AddAttr
<
std
::
string
>
(
"quantize_type"
,
"(string, default abs_max)"
"The scaling tpe of the quantize operator."
)
.
SetDefault
(
"abs_max"
);
AddAttr
<
int
>
(
"window_size"
,
"(int, default 10000)"
).
SetDefault
(
10000
);
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8)"
)
.
SetDefault
(
8
)
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
PADDLE_ENFORCE
(
bit_length
>=
1
&&
bit_length
<=
16
,
"'bit_length' should be between 1 and 16."
);
});
AddAttr
<
bool
>
(
"is_test"
,
""
).
SetDefault
(
false
);
AddComment
(
R"DOC(
FakeQuantize operator
quantize_type = abs_max:
$$scale = max(abs(x))$$
quantize_type = range_abs_max:
$$scale = max(max(abs(x)), history_abs_max)$$
quantize_type = moving_average_abs_max:
$$scale = 0.1*scale+0.9*new_abs_max)$$
$$Out = scale*X$$
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
fake_quantize
,
ops
::
FakeQuantizeOp
,
ops
::
FakeQuantizeOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
fake_quantize
,
ops
::
FakeQuantizeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
FakeQuantizeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/fake_quantize_op.cu
0 → 100644
浏览文件 @
8e4b225f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
__global__
void
FindAbsMaxKernel
(
const
int
n
,
const
T
*
in
,
T
*
out
)
{
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
tid
=
threadIdx
.
x
;
extern
__shared__
T
shared_max_data
[];
if
(
gridDim
.
x
>
1
)
{
shared_max_data
[
tid
]
=
T
(
0
);
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
tmp
=
fabs
(
in
[
i
]);
if
(
tmp
>
shared_max_data
[
tid
])
{
shared_max_data
[
tid
]
=
tmp
;
}
}
}
else
{
if
(
bid
<
n
)
{
shared_max_data
[
tid
]
=
fabs
(
in
[
bid
]);
}
else
{
shared_max_data
[
tid
]
=
T
(
0
);
}
}
__syncthreads
();
for
(
int
i
=
blockDim
.
x
/
2
;
i
>
0
;
i
>>=
1
)
{
if
(
tid
<
i
&&
shared_max_data
[
tid
]
<
shared_max_data
[
tid
+
i
])
{
shared_max_data
[
tid
]
=
shared_max_data
[
tid
+
i
];
}
__syncthreads
();
}
if
(
tid
==
0
)
{
out
[
blockIdx
.
x
]
=
shared_max_data
[
0
];
}
}
float
FindAbsMaxGpu
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
float
*
array
,
int
length
)
{
float
host_max
;
int
kNumTheads
=
1024
;
int
gridDimx
=
(
kNumTheads
-
1
+
length
)
/
kNumTheads
;
gridDimx
=
(
gridDimx
>
kNumTheads
)
?
kNumTheads
:
gridDimx
;
framework
::
Tensor
t
;
float
*
device_max
=
t
.
mutable_data
<
float
>
(
framework
::
make_ddim
({
gridDimx
}),
platform
::
CUDAPlace
());
FindAbsMaxKernel
<
float
><<<
gridDimx
,
kNumTheads
,
kNumTheads
*
sizeof
(
float
),
ctx
.
stream
()
>>>
(
length
,
array
,
device_max
);
FindAbsMaxKernel
<
float
><<<
1
,
kNumTheads
,
kNumTheads
*
sizeof
(
float
),
ctx
.
stream
()
>>>
(
gridDimx
,
device_max
,
device_max
);
PADDLE_ENFORCE_EQ
(
cudaMemcpy
(
&
host_max
,
device_max
,
sizeof
(
float
),
cudaMemcpyDeviceToHost
),
cudaSuccess
,
"cudaMemcpy failed"
);
return
host_max
;
}
template
<
typename
T
>
__global__
void
ApplySaturateKernel
(
const
int
n
,
const
T
*
in
,
T
*
out
,
int
*
num_saturate
,
const
T
min
,
const
T
max
)
{
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
tid
=
threadIdx
.
x
;
extern
__shared__
int
shared_count
[];
shared_count
[
tid
]
=
0
;
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
if
(
in
[
i
]
>
max
)
{
out
[
i
]
=
max
;
shared_count
[
tid
]
+=
1
;
}
else
if
(
in
[
i
]
<
min
)
{
out
[
i
]
=
min
;
shared_count
[
tid
]
+=
1
;
}
else
{
out
[
i
]
=
in
[
i
];
}
}
__syncthreads
();
for
(
int
i
=
blockDim
.
x
/
2
;
i
>
0
;
i
>>=
1
)
{
if
(
tid
<
i
)
{
shared_count
[
tid
]
+=
shared_count
[
tid
+
i
];
}
__syncthreads
();
}
if
(
tid
==
0
)
{
num_saturate
[
blockIdx
.
x
]
=
shared_count
[
0
];
}
}
template
<
typename
T
>
__global__
void
ReduceKernel
(
const
int
n
,
const
T
*
in
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
;
extern
__shared__
T
shared_sum
[];
if
(
tid
<
n
)
{
shared_sum
[
tid
]
=
in
[
tid
];
}
else
{
shared_sum
[
tid
]
=
T
(
0
);
}
__syncthreads
();
// blockDim.x must >= n
for
(
int
i
=
(
n
+
1
)
/
2
;
i
>
0
;
i
>>=
1
)
{
if
(
tid
<
i
)
{
shared_sum
[
tid
]
+=
shared_sum
[
tid
+
i
];
}
__syncthreads
();
}
if
(
tid
==
0
)
{
out
[
0
]
=
shared_sum
[
0
];
}
}
template
<
typename
T
>
int
ApplySaturateGpu
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
int
n
,
const
T
*
in
,
T
*
out
,
const
T
min
,
const
T
max
)
{
int
host_num_saturate
;
int
kNumTheads
=
1024
;
int
gridDimx
=
(
n
+
kNumTheads
-
1
)
/
kNumTheads
;
gridDimx
=
(
gridDimx
>
kNumTheads
)
?
kNumTheads
:
gridDimx
;
framework
::
Tensor
t
;
int
*
device_num_saturate
=
t
.
mutable_data
<
int
>
(
framework
::
make_ddim
({
gridDimx
}),
platform
::
CUDAPlace
());
ApplySaturateKernel
<
T
><<<
gridDimx
,
kNumTheads
,
kNumTheads
*
sizeof
(
T
),
ctx
.
stream
()
>>>
(
n
,
in
,
out
,
device_num_saturate
,
min
,
max
);
ReduceKernel
<
int
><<<
1
,
kNumTheads
,
kNumTheads
*
sizeof
(
T
),
ctx
.
stream
()
>>>
(
gridDimx
,
device_num_saturate
,
device_num_saturate
);
PADDLE_ENFORCE_EQ
(
cudaSuccess
,
cudaMemcpy
(
&
host_num_saturate
,
device_num_saturate
,
sizeof
(
int
),
cudaMemcpyDeviceToHost
),
"cudaMemcpy failed"
);
return
host_num_saturate
;
}
template
<
typename
DeviceContext
,
typename
T
>
class
FakeQuantizeCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
T
FindRangeAbsMax
(
const
platform
::
CUDADeviceContext
&
ctx
,
framework
::
Tensor
*
scale_list
,
framework
::
Tensor
*
out_scale
,
const
T
&
cur_scale
,
int
window_size
,
int
current_iter
)
const
{
T
*
sl
=
scale_list
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
T
remove_tmp
=
sl
[
current_iter
];
sl
[
current_iter
]
=
cur_scale
;
T
&
max_scale
=
out_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
if
(
max_scale
<
cur_scale
)
{
max_scale
=
cur_scale
;
}
else
if
(
fabs
(
remove_tmp
-
max_scale
)
<
1e-6
)
{
int
size
=
(
current_iter
>
window_size
)
?
window_size
:
current_iter
;
max_scale
=
T
(
FindAbsMaxGpu
(
ctx
,
scale_list
->
data
<
float
>
(),
size
));
}
return
max_scale
;
}
T
FindMovingAverageAbsMmax
(
framework
::
Tensor
*
in_scale
,
framework
::
Tensor
*
out_scale
,
const
T
&
cur_scale
)
const
{
T
*
ins
=
in_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
T
*
outs
=
out_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
outs
[
0
]
=
0.9
*
cur_scale
+
0.1
*
ins
[
0
];
return
T
(
outs
[
0
]);
}
virtual
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
context
.
GetPlace
()),
"This kernel only runs on GPU device."
);
auto
&
device_ctx
=
context
.
cuda_device_context
();
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
const
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
tensor
->
mutable_data
<
T
>
(
in
->
place
());
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
)
->
mutable_data
<
T
>
(
context
.
Input
<
framework
::
Tensor
>
(
"InMovingScale"
)
->
place
());
auto
quantize_type
=
static_cast
<
std
::
string
>
(
context
.
Attr
<
std
::
string
>
(
"quantize_type"
));
if
(
quantize_type
==
std
::
string
(
"range_abs_max"
))
{
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
)
->
mutable_data
<
T
>
(
context
.
Input
<
framework
::
Tensor
>
(
"InScales"
)
->
place
());
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
)
->
mutable_data
<
T
>
(
context
.
Input
<
framework
::
Tensor
>
(
"InCurrentIter"
)
->
place
());
}
T
scale
=
T
(
1
);
int
window_size
=
context
.
Attr
<
int
>
(
"window_size"
);
T
bin_cnt
=
(
T
)((
1
<<
(
context
.
Attr
<
int
>
(
"bit_length"
)
-
1
))
-
1
);
if
(
quantize_type
==
std
::
string
(
"abs_max"
))
{
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
scale
=
(
T
)
FindAbsMaxGpu
(
device_ctx
,
in
->
data
<
float
>
(),
in
->
numel
());
saving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
]
=
scale
;
auto
&
device_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
*
scale_list
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
math
::
SetConstant
<
DeviceContext
,
T
>
scalar
;
scale_list
->
mutable_data
<
T
>
(
context
.
GetPlace
());
scalar
(
device_ctx
,
scale_list
,
static_cast
<
T
>
(
0
));
auto
*
iter
=
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
);
iter
->
mutable_data
<
T
>
(
context
.
GetPlace
());
scalar
(
device_ctx
,
iter
,
static_cast
<
T
>
(
0
));
}
else
if
(
quantize_type
==
std
::
string
(
"range_abs_max"
))
{
auto
*
moving_scale
=
const_cast
<
framework
::
Tensor
*>
(
context
.
Input
<
framework
::
Tensor
>
(
"InMovingScale"
));
if
(
is_test
)
{
scale
=
moving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
}
else
{
auto
*
it
=
const_cast
<
framework
::
Tensor
*>
(
context
.
Input
<
framework
::
Tensor
>
(
"InCurrentIter"
));
auto
*
iter
=
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
);
int
*
last_iter
=
it
->
mutable_data
<
int
>
(
platform
::
CPUPlace
());
int
*
current_iter
=
iter
->
mutable_data
<
int
>
(
platform
::
CPUPlace
());
auto
*
scale_list
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
scale
=
(
T
)
FindAbsMaxGpu
(
device_ctx
,
in
->
data
<
float
>
(),
in
->
numel
());
scale
=
FindRangeAbsMax
(
device_ctx
,
scale_list
,
saving_scale
,
scale
,
window_size
,
current_iter
[
0
]);
(
*
current_iter
)
=
(
*
last_iter
)
+
1
;
}
}
else
if
(
quantize_type
==
std
::
string
(
"moving_average_abs_max"
))
{
auto
*
moving_scale
=
const_cast
<
framework
::
Tensor
*>
(
context
.
Input
<
framework
::
Tensor
>
(
"InMovingScale"
));
if
(
is_test
)
{
scale
=
moving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
}
else
{
scale
=
(
T
)
FindAbsMaxGpu
(
device_ctx
,
in
->
data
<
float
>
(),
in
->
numel
());
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
scale
=
FindMovingAverageAbsMmax
(
const_cast
<
framework
::
Tensor
*>
(
moving_scale
),
saving_scale
,
scale
);
}
}
ApplySaturateGpu
<
T
>
(
device_ctx
,
in
->
numel
(),
in
->
data
<
T
>
(),
tensor
->
mutable_data
<
T
>
(
in
->
place
()),
-
scale
,
scale
);
scale
=
bin_cnt
/
scale
;
auto
&
dev
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
eigen_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
tensor
);
auto
eigen_in
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
tensor
);
eigen_out
.
device
(
dev
)
=
(
scale
*
eigen_in
).
round
();
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
fake_quantize
,
paddle
::
operators
::
FakeQuantizeCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
paddle
::
operators
::
FakeQuantizeCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/fake_quantize_op.h
0 → 100644
浏览文件 @
8e4b225f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/transform.h"
namespace
paddle
{
namespace
operators
{
using
platform
::
Transform
;
template
<
typename
DeviceContext
,
typename
T
>
class
FakeQuantizeKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
T
FindAbsMax
(
framework
::
Tensor
*
in
,
int
n
)
const
{
T
*
p
=
in
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
T
abs_max
=
(
T
)
0.00000001
;
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
T
tmp
=
fabs
(
p
[
i
]);
if
(
tmp
>
abs_max
)
abs_max
=
tmp
;
}
return
T
(
abs_max
);
}
T
FindRangeAbsMax
(
framework
::
Tensor
*
scale_list
,
framework
::
Tensor
*
out_scale
,
const
T
&
cur_scale
,
int
window_size
,
int
current_iter
)
const
{
T
*
sl
=
scale_list
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
T
remove_tmp
=
sl
[
current_iter
];
sl
[
current_iter
]
=
cur_scale
;
T
&
max_scale
=
out_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
if
(
max_scale
<
cur_scale
)
{
max_scale
=
cur_scale
;
}
else
if
(
fabs
(
remove_tmp
-
max_scale
)
<
1e-6
)
{
int
size
=
(
current_iter
>
window_size
)
?
window_size
:
current_iter
;
max_scale
=
T
(
FindAbsMax
(
scale_list
,
size
));
}
return
max_scale
;
}
T
FindMovingAverageAbsMmax
(
framework
::
Tensor
*
in_scale
,
framework
::
Tensor
*
out_scale
,
const
T
&
cur_scale
)
const
{
T
*
ins
=
in_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
T
*
outs
=
out_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
outs
[
0
]
=
0.9
*
cur_scale
+
0.1
*
ins
[
0
];
return
T
(
outs
[
0
]);
}
virtual
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
const
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
tensor
->
mutable_data
<
T
>
(
in
->
place
());
auto
*
oms_tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
oms_tensor
->
mutable_data
<
T
>
(
in
->
place
());
auto
quantize_type
=
static_cast
<
std
::
string
>
(
context
.
Attr
<
std
::
string
>
(
"quantize_type"
));
if
(
quantize_type
==
std
::
string
(
"range_abs_max"
))
{
auto
*
oss_tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
oss_tensor
->
mutable_data
<
T
>
(
context
.
Input
<
framework
::
Tensor
>
(
"InScales"
)
->
place
());
auto
*
oci_tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
);
oci_tensor
->
mutable_data
<
T
>
(
context
.
Input
<
framework
::
Tensor
>
(
"InCurrentIter"
)
->
place
());
}
T
scale
=
static_cast
<
T
>
(
1
);
int
window_size
=
context
.
Attr
<
int
>
(
"window_size"
);
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
auto
&
dev
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
raw_in
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
in
);
if
(
quantize_type
==
std
::
string
(
"abs_max"
))
{
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
auto
scale_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
saving_scale
);
scale_out
.
device
(
dev
)
=
raw_in
.
abs
().
maximum
();
scale
=
scale_out
(
0
);
auto
&
device_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
*
scale_list
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
math
::
SetConstant
<
DeviceContext
,
T
>
scalar
;
scale_list
->
mutable_data
<
T
>
(
context
.
GetPlace
());
scalar
(
device_ctx
,
scale_list
,
static_cast
<
T
>
(
0
));
auto
*
iter
=
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
);
iter
->
mutable_data
<
T
>
(
context
.
GetPlace
());
scalar
(
device_ctx
,
iter
,
static_cast
<
T
>
(
0
));
}
else
if
(
quantize_type
==
std
::
string
(
"range_abs_max"
))
{
auto
*
moving_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"InMovingScale"
);
if
(
is_test
)
{
scale
=
moving_scale
->
data
<
T
>
()[
0
];
}
else
{
auto
*
it
=
context
.
Input
<
framework
::
Tensor
>
(
"InCurrentIter"
);
auto
*
iter
=
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
);
const
int
*
last_iter
=
it
->
data
<
int
>
();
int
*
current_iter
=
iter
->
mutable_data
<
int
>
(
platform
::
CPUPlace
());
auto
*
scale_list
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
auto
scale_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
saving_scale
);
scale_out
.
device
(
dev
)
=
raw_in
.
abs
().
maximum
();
scale
=
saving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
scale
=
FindRangeAbsMax
(
scale_list
,
saving_scale
,
scale
,
window_size
,
current_iter
[
0
]);
saving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
]
=
scale
;
(
*
current_iter
)
=
(
*
last_iter
)
+
1
;
}
}
else
if
(
quantize_type
==
std
::
string
(
"moving_average_abs_max"
))
{
auto
*
moving_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"InMovingScale"
);
if
(
is_test
)
{
scale
=
moving_scale
->
data
<
T
>
()[
0
];
}
else
{
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
auto
scale_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
saving_scale
);
scale_out
.
device
(
dev
)
=
raw_in
.
abs
().
maximum
();
scale
=
saving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
scale
=
FindMovingAverageAbsMmax
(
const_cast
<
framework
::
Tensor
*>
(
moving_scale
),
saving_scale
,
scale
);
saving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
]
=
scale
;
}
}
Transform
<
DeviceContext
>
trans
;
trans
(
context
.
template
device_context
<
DeviceContext
>(),
in
->
data
<
T
>
(),
in
->
data
<
T
>
()
+
in
->
numel
(),
tensor
->
mutable_data
<
T
>
(
in
->
place
()),
ClipFunctor
<
T
>
(
-
scale
,
scale
));
auto
eigen_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
tensor
);
auto
eigen_in
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
tensor
);
eigen_out
.
device
(
dev
)
=
(
bin_cnt
/
scale
*
eigen_in
).
round
();
}
};
}
// namespace operators
}
// namespace paddle
paddle/legacy/capi/examples/model_inference/multi_thread/convert_protobin.sh
已删除
120000 → 0
浏览文件 @
79d797fd
../dense/convert_protobin.sh
\ No newline at end of file
paddle/legacy/capi/examples/model_inference/multi_thread/convert_protobin.sh
0 → 100644
浏览文件 @
8e4b225f
../dense/convert_protobin.sh
paddle/legacy/capi/examples/model_inference/sequence/convert_protobin.sh
已删除
120000 → 0
浏览文件 @
79d797fd
../dense/convert_protobin.sh
\ No newline at end of file
paddle/legacy/capi/examples/model_inference/sequence/convert_protobin.sh
0 → 100644
浏览文件 @
8e4b225f
../dense/convert_protobin.sh
paddle/legacy/capi/examples/model_inference/sparse_binary/convert_protobin.sh
已删除
120000 → 0
浏览文件 @
79d797fd
../dense/convert_protobin.sh
\ No newline at end of file
paddle/legacy/capi/examples/model_inference/sparse_binary/convert_protobin.sh
0 → 100644
浏览文件 @
8e4b225f
../dense/convert_protobin.sh
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
0 → 100644
浏览文件 @
8e4b225f
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
class
TestFakeQuantizeOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"fake_quantize"
self
.
attrs
=
{
'bit_length'
:
8
,
'quantize_type'
:
'abs_max'
,
'window_size'
:
10000
}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
10
,
10
)).
astype
(
"float32"
),
'InScales'
:
np
.
zeros
(
self
.
attrs
[
'window_size'
]).
astype
(
"float32"
),
'InCurrentIter'
:
np
.
zeros
(
1
).
astype
(
"float32"
),
'InMovingScale'
:
np
.
zeros
(
1
).
astype
(
"float32"
)
}
self
.
scale
=
{
'abs_max'
:
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
])).
astype
(
"float32"
)
}
self
.
outputs
=
{
'Out'
:
np
.
round
(
self
.
inputs
[
'X'
]
/
self
.
scale
[
'abs_max'
]
*
(
(
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
)),
'OutScales'
:
np
.
zeros
(
self
.
attrs
[
'window_size'
]).
astype
(
"float32"
),
'OutMovingScale'
:
np
.
array
([
self
.
scale
[
'abs_max'
]]).
astype
(
"float32"
),
'OutCurrentIter'
:
np
.
zeros
(
1
).
astype
(
"float32"
)
}
def
test_check_output
(
self
):
self
.
check_output
()
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录