Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b950d3a3
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b950d3a3
编写于
6月 20, 2019
作者:
X
xingzhaolong
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'xzl/incubate/lite' into 'incubate/lite'
init . ARM INT8 support See merge request inference/paddlelite!33
上级
ad333ac5
4fe5c8aa
变更
24
隐藏空白更改
内联
并排
Showing
24 changed file
with
550 addition
and
23 deletion
+550
-23
paddle/fluid/lite/api/cxx_api_bin.cc
paddle/fluid/lite/api/cxx_api_bin.cc
+2
-0
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+1
-0
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
+5
-1
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc
.../core/mir/fusion/conv_elementwise_add_activation_fuser.cc
+1
-2
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc
...te/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc
...re/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass_test.cc
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass_test.cc
+1
-1
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
+45
-0
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.h
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.h
+33
-0
paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc
paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc
+174
-0
paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h
paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h
+58
-0
paddle/fluid/lite/core/mir/pattern_matcher.cc
paddle/fluid/lite/core/mir/pattern_matcher.cc
+1
-2
paddle/fluid/lite/core/mir/use_passes.h
paddle/fluid/lite/core/mir/use_passes.h
+3
-4
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+3
-4
paddle/fluid/lite/core/target_wrapper.h
paddle/fluid/lite/core/target_wrapper.h
+2
-2
paddle/fluid/lite/core/type_system.h
paddle/fluid/lite/core/type_system.h
+2
-2
paddle/fluid/lite/kernels/arm/conv_compute.cc
paddle/fluid/lite/kernels/arm/conv_compute.cc
+2
-2
paddle/fluid/lite/operators/CMakeLists.txt
paddle/fluid/lite/operators/CMakeLists.txt
+4
-0
paddle/fluid/lite/operators/fake_dequantize_max_abs.cc
paddle/fluid/lite/operators/fake_dequantize_max_abs.cc
+25
-0
paddle/fluid/lite/operators/fake_dequantize_max_abs.h
paddle/fluid/lite/operators/fake_dequantize_max_abs.h
+64
-0
paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.cc
.../fluid/lite/operators/fake_quantize_moving_avg_max_abs.cc
+25
-0
paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.h
...e/fluid/lite/operators/fake_quantize_moving_avg_max_abs.h
+69
-0
paddle/fluid/lite/operators/op_params.h
paddle/fluid/lite/operators/op_params.h
+22
-0
paddle/fluid/lite/operators/softmax_op.cc
paddle/fluid/lite/operators/softmax_op.cc
+6
-1
未找到文件。
paddle/fluid/lite/api/cxx_api_bin.cc
浏览文件 @
b950d3a3
...
@@ -88,6 +88,8 @@ USE_LITE_OP(depthwise_conv2d);
...
@@ -88,6 +88,8 @@ USE_LITE_OP(depthwise_conv2d);
USE_LITE_OP
(
pool2d
);
USE_LITE_OP
(
pool2d
);
USE_LITE_OP
(
elementwise_add
);
USE_LITE_OP
(
elementwise_add
);
USE_LITE_OP
(
softmax
);
USE_LITE_OP
(
softmax
);
USE_LITE_OP
(
fake_quantize_moving_average_abs_max
);
USE_LITE_OP
(
fake_dequantize_max_abs
);
USE_LITE_KERNEL
(
feed
,
kHost
,
kAny
,
kAny
,
def
);
USE_LITE_KERNEL
(
feed
,
kHost
,
kAny
,
kAny
,
def
);
USE_LITE_KERNEL
(
fetch
,
kHost
,
kAny
,
kAny
,
def
);
USE_LITE_KERNEL
(
fetch
,
kHost
,
kAny
,
kAny
,
def
);
...
...
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
b950d3a3
...
@@ -13,6 +13,7 @@ cc_library(mir_passes
...
@@ -13,6 +13,7 @@ cc_library(mir_passes
fusion/conv_elementwise_add_activation_fuse_pass.cc
fusion/conv_elementwise_add_activation_fuse_pass.cc
fusion/conv_bn_fuse_pass.cc
fusion/conv_bn_fuse_pass.cc
fusion/elementwise_add_activation_fuse_pass.cc
fusion/elementwise_add_activation_fuse_pass.cc
fusion/quant_dequant_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
elimination/identity_scale_eliminate_pass.cc
static_kernel_pick_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
variable_place_inference_pass.cc
...
...
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
浏览文件 @
b950d3a3
...
@@ -10,11 +10,15 @@ cc_library(fuse_conv_bn
...
@@ -10,11 +10,15 @@ cc_library(fuse_conv_bn
cc_library
(
fuse_elementwise_add_activation
cc_library
(
fuse_elementwise_add_activation
SRCS elementwise_add_activation_fuser.cc
SRCS elementwise_add_activation_fuser.cc
DEPS pattern_matcher_high_api
)
DEPS pattern_matcher_high_api
)
cc_library
(
fuse_quant_dequant
SRCS quant_dequant_op_fuser.cc
DEPS pattern_matcher_high_api
)
set
(
mir_fusers
set
(
mir_fusers
fuse_fc
fuse_fc
fuse_conv_elementwise_add_activation
fuse_conv_elementwise_add_activation
fuse_conv_bn
fuse_conv_bn
fuse_quant_dequant
fuse_elementwise_add_activation
fuse_elementwise_add_activation
CACHE INTERNAL
"fusers"
)
CACHE INTERNAL
"fusers"
)
...
...
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc
浏览文件 @
b950d3a3
...
@@ -82,7 +82,7 @@ cpp::OpDesc ConvElementwiseAddActivationFuser::GenOpDesc(
...
@@ -82,7 +82,7 @@ cpp::OpDesc ConvElementwiseAddActivationFuser::GenOpDesc(
const
key2nodes_t
&
matched
)
{
const
key2nodes_t
&
matched
)
{
auto
*
desc
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op_info
();
auto
*
desc
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op_info
();
cpp
::
OpDesc
op_desc
;
cpp
::
OpDesc
op_desc
=
*
desc
;
op_desc
.
SetType
(
conv_type_
);
op_desc
.
SetType
(
conv_type_
);
op_desc
.
SetInput
(
"Input"
,
{
matched
.
at
(
"input"
)
->
arg
()
->
name
});
op_desc
.
SetInput
(
"Input"
,
{
matched
.
at
(
"input"
)
->
arg
()
->
name
});
op_desc
.
SetInput
(
"Filter"
,
{
matched
.
at
(
"filter"
)
->
arg
()
->
name
});
op_desc
.
SetInput
(
"Filter"
,
{
matched
.
at
(
"filter"
)
->
arg
()
->
name
});
...
@@ -95,7 +95,6 @@ cpp::OpDesc ConvElementwiseAddActivationFuser::GenOpDesc(
...
@@ -95,7 +95,6 @@ cpp::OpDesc ConvElementwiseAddActivationFuser::GenOpDesc(
"ResidualData"
)
!=
input_arg_names
.
end
())
{
"ResidualData"
)
!=
input_arg_names
.
end
())
{
op_desc
.
SetInput
(
"ResidualData"
,
desc
->
Input
(
"ResidualData"
));
op_desc
.
SetInput
(
"ResidualData"
,
desc
->
Input
(
"ResidualData"
));
}
}
// Only consider strides, padding, groups, dilations, fuse_relu for now
// Only consider strides, padding, groups, dilations, fuse_relu for now
op_desc
.
SetAttr
(
"strides"
,
desc
->
GetAttr
<
std
::
vector
<
int
>>
(
"strides"
));
op_desc
.
SetAttr
(
"strides"
,
desc
->
GetAttr
<
std
::
vector
<
int
>>
(
"strides"
));
op_desc
.
SetAttr
(
"paddings"
,
desc
->
GetAttr
<
std
::
vector
<
int
>>
(
"paddings"
));
op_desc
.
SetAttr
(
"paddings"
,
desc
->
GetAttr
<
std
::
vector
<
int
>>
(
"paddings"
));
...
...
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.cc
浏览文件 @
b950d3a3
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "conv_elementwise_add_relu_fuse_pass.h"
#include "
paddle/fluid/lite/core/mir/fusion/
conv_elementwise_add_relu_fuse_pass.h"
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
...
...
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass_test.cc
浏览文件 @
b950d3a3
...
@@ -12,13 +12,13 @@
...
@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "conv_elementwise_add_relu_fuse_pass.h"
#include <gflags/gflags.h>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/op_registry.h"
...
...
paddle/fluid/lite/core/mir/fusion/fc_fuse_pass_test.cc
浏览文件 @
b950d3a3
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "fc_fuse_pass.h"
#include "
paddle/fluid/lite/core/mir/fusion/
fc_fuse_pass.h"
#include <gflags/gflags.h>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <vector>
#include <vector>
...
...
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
0 → 100644
浏览文件 @
b950d3a3
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
QuantDequantFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
std
::
unordered_set
<
std
::
string
>
quant_types
=
{
"fake_quantize_range_abs_max"
,
"fake_quantize_moving_average_abs_max"
};
std
::
unordered_set
<
std
::
string
>
quantized_op_types
=
{
"conv2d"
,
"mul"
,
"depthwise_conv2d"
};
for
(
auto
&
quant_type
:
quant_types
)
{
for
(
auto
&
op_type
:
quantized_op_types
)
{
for
(
int
i
=
6
;
i
>=
1
;
i
--
)
{
fusion
::
QuantDequantOpFuser
fuser
(
op_type
,
quant_type
,
i
);
fuser
(
graph
.
get
());
}
}
}
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_quant_dequant_fuse_pass
,
paddle
::
lite
::
mir
::
QuantDequantFusePass
);
paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.h
0 → 100644
浏览文件 @
b950d3a3
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_set>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
QuantDequantFusePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.cc
0 → 100644
浏览文件 @
b950d3a3
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h"
#include <memory>
#include <vector>
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
void
QuantDequantOpFuser
::
BuildPattern
()
{
const
int
kNumFields
=
5
;
const
int
kQuantizedWeightOffset
=
0
;
const
int
kQuantizedOpOffset
=
1
;
const
int
kQuantizedOpOutOffset
=
2
;
const
int
kDequantOpOffset
=
3
;
const
int
kDequantOpOutOffset
=
4
;
std
::
string
weight_name
=
""
;
if
(
op_type_
==
"conv2d"
||
op_type_
==
"depthwise_conv2d"
)
{
weight_name
=
"Filter"
;
}
else
{
weight_name
=
"Y"
;
}
auto
*
quant_op_input
=
VarNode
(
"quant_op_input"
)
->
assert_is_op_input
(
quant_type_
,
"X"
)
->
AsInput
();
auto
*
quant_op_in_scale
=
VarNode
(
"quant_op_in_scale"
)
->
assert_is_op_input
(
quant_type_
,
"InScale"
)
->
AsIntermediate
();
auto
*
quant_op
=
OpNode
(
"quant_op"
,
quant_type_
)
->
assert_is_op
(
quant_type_
)
->
AsIntermediate
();
auto
*
quant_op_out_scale
=
VarNode
(
"quant_op_out_scale"
)
->
assert_is_op_output
(
quant_type_
,
"OutScale"
)
->
assert_is_op_input
(
"fake_dequantize_max_abs"
,
"Scale"
)
->
AsIntermediate
();
auto
*
quant_op_out
=
VarNode
(
"quant_op_out"
)
->
assert_is_op_output
(
quant_type_
,
"Out"
)
->
assert_is_op_input
(
op_type_
)
->
AsIntermediate
();
std
::
vector
<
PMNode
*>
nodes
;
for
(
int
i
=
0
;
i
<
times_
;
i
++
)
{
nodes
.
push_back
(
VarNode
(
"quantized_op_weight"
+
std
::
to_string
(
i
))
->
assert_is_op_input
(
op_type_
,
weight_name
)
->
AsInput
());
nodes
.
push_back
(
OpNode
(
"quantized_op"
+
std
::
to_string
(
i
),
op_type_
)
->
assert_is_op
(
op_type_
)
->
AsIntermediate
());
nodes
.
push_back
(
VarNode
(
"quantized_op_out"
+
std
::
to_string
(
i
))
->
assert_is_op_output
(
op_type_
)
->
assert_is_op_input
(
"fake_dequantize_max_abs"
,
"X"
)
->
AsIntermediate
());
nodes
.
push_back
(
OpNode
(
"dequant_op"
+
std
::
to_string
(
i
),
"fake_dequantize_max_abs"
)
->
assert_is_op
(
"fake_dequantize_max_abs"
)
->
AsIntermediate
());
nodes
.
push_back
(
VarNode
(
"dequant_op_out"
+
std
::
to_string
(
i
))
->
assert_is_op_output
(
"fake_dequantize_max_abs"
,
"Out"
)
->
AsOutput
());
}
quant_op
->
LinksFrom
({
quant_op_input
,
quant_op_in_scale
});
quant_op_out
->
LinksFrom
({
quant_op
});
quant_op_out_scale
->
LinksFrom
({
quant_op
});
for
(
int
i
=
0
;
i
<
times_
;
i
++
)
{
nodes
[
i
*
kNumFields
+
kQuantizedOpOffset
]
->
LinksFrom
(
{
quant_op_out
,
nodes
[
i
*
kNumFields
+
kQuantizedWeightOffset
]});
nodes
[
i
*
kNumFields
+
kQuantizedOpOutOffset
]
->
LinksFrom
(
{
nodes
[
i
*
kNumFields
+
kQuantizedOpOffset
]});
nodes
[
i
*
kNumFields
+
kDequantOpOffset
]
->
LinksFrom
(
{
nodes
[
i
*
kNumFields
+
kQuantizedOpOutOffset
],
quant_op_out_scale
});
nodes
[
i
*
kNumFields
+
kDequantOpOutOffset
]
->
LinksFrom
(
{
nodes
[
i
*
kNumFields
+
kDequantOpOffset
]});
}
}
void
QuantDequantOpFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
const
int
kNumFields
=
5
;
const
int
kQuantizedWeightOffset
=
0
;
const
int
kQuantizedOpOffset
=
1
;
const
int
kDequantOpOffset
=
3
;
const
int
kDequantOpOutOffset
=
4
;
auto
*
quant_op_input
=
matched
.
at
(
"quant_op_input"
);
auto
*
quant_op_in_scale
=
matched
.
at
(
"quant_op_in_scale"
);
auto
*
quant_op
=
matched
.
at
(
"quant_op"
);
std
::
vector
<
Node
*>
nodes
;
for
(
int
i
=
0
;
i
<
times_
;
i
++
)
{
nodes
.
push_back
(
matched
.
at
(
"quantized_op_weight"
+
std
::
to_string
(
i
)));
nodes
.
push_back
(
matched
.
at
(
"quantized_op"
+
std
::
to_string
(
i
)));
nodes
.
push_back
(
matched
.
at
(
"quantized_op_out"
+
std
::
to_string
(
i
)));
nodes
.
push_back
(
matched
.
at
(
"dequant_op"
+
std
::
to_string
(
i
)));
nodes
.
push_back
(
matched
.
at
(
"dequant_op_out"
+
std
::
to_string
(
i
)));
}
int
bit_length
=
quant_op
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
auto
*
scope
=
quant_op
->
stmt
()
->
op
()
->
scope
();
auto
&
valid_places
=
quant_op
->
stmt
()
->
op
()
->
valid_places
();
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
auto
input_scale_t
=
scope
->
FindVar
(
quant_op_in_scale
->
arg
()
->
name
)
->
GetMutable
<
lite
::
Tensor
>
();
float
input_scale
=
input_scale_t
->
data
<
float
>
()[
0
];
for
(
int
i
=
0
;
i
<
times_
;
i
++
)
{
float
max_range
=
nodes
[
i
*
kNumFields
+
kDequantOpOffset
]
->
stmt
()
->
op_info
()
->
GetAttr
<
float
>
(
"max_range"
);
float
weight_scale
=
(
range
*
range
)
/
max_range
;
cpp
::
OpDesc
op_desc
=
*
nodes
[
i
*
kNumFields
+
kQuantizedOpOffset
]
->
stmt
()
->
op_info
();
if
(
op_type_
==
"conv2d"
||
op_type_
==
"depthwise_conv2d"
)
{
op_desc
.
SetInput
(
"Input"
,
{
matched
.
at
(
"quant_op_input"
)
->
arg
()
->
name
});
op_desc
.
SetOutput
(
"Output"
,
{
nodes
[
i
*
kNumFields
+
kDequantOpOutOffset
]
->
arg
()
->
name
});
}
else
if
(
op_type_
==
"mul"
)
{
op_desc
.
SetInput
(
"X"
,
{
matched
.
at
(
"quant_op_input"
)
->
arg
()
->
name
});
op_desc
.
SetOutput
(
"Out"
,
{
nodes
[
i
*
kNumFields
+
kDequantOpOutOffset
]
->
arg
()
->
name
});
}
op_desc
.
SetAttr
(
"enable_int8"
,
true
);
op_desc
.
SetAttr
(
"input_scale"
,
input_scale
);
auto
quantized_weight_var_name
=
nodes
[
i
*
kNumFields
+
kQuantizedWeightOffset
]
->
arg
()
->
name
;
auto
quantized_weight_t
=
scope
->
FindVar
(
quantized_weight_var_name
)
->
GetMutable
<
lite
::
Tensor
>
();
float
*
quantized_weight_data
=
quantized_weight_t
->
mutable_data
<
float
>
();
size_t
weight_num
=
quantized_weight_t
->
data_size
();
for
(
size_t
i
=
0
;
i
<
weight_num
;
i
++
)
{
quantized_weight_data
[
i
]
*=
(
weight_scale
/
range
);
}
auto
quantized_op
=
LiteOpRegistry
::
Global
().
Create
(
op_type_
);
quantized_op
->
Attach
(
op_desc
,
scope
);
auto
*
new_op_node
=
graph
->
GraphCreateInstructNode
(
quantized_op
,
valid_places
);
IR_NODE_LINK_TO
(
quant_op_input
,
new_op_node
);
IR_NODE_LINK_TO
(
nodes
[
i
*
kNumFields
+
kQuantizedWeightOffset
],
new_op_node
);
IR_NODE_LINK_TO
(
new_op_node
,
nodes
[
i
*
kNumFields
+
kDequantOpOutOffset
]);
}
}
cpp
::
OpDesc
QuantDequantOpFuser
::
GenOpDesc
(
const
key2nodes_t
&
matched
)
{
cpp
::
OpDesc
op_desc
;
return
op_desc
;
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h
0 → 100644
浏览文件 @
b950d3a3
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
/* The model trained by fluid quantization is a simulation of real int8.
* The quantized Ops(conv2d, mul, depthwise conv2d etc) have fake_quantop
* in front and fake_dequantop behind.
*
* When in int8 mode, the pattern like "fake_quant + quantized_op +
* fake_dequant"
* can be detected by this fuser. The fuser extract the input_scale and
* the weight_scale info from fake_quant, fake_dequant op and fuse those into
* the quantized_op.
* In addition, the fuser delete fake_quant and fake_dequant op in the graph at
* the last.
*/
class
QuantDequantOpFuser
:
public
FuseBase
{
public:
explicit
QuantDequantOpFuser
(
const
std
::
string
&
op_type
,
const
std
::
string
&
quant_type
,
int
times
)
:
op_type_
(
op_type
),
quant_type_
(
quant_type
),
times_
(
times
)
{}
void
BuildPattern
()
override
;
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
;
private:
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
override
;
private:
std
::
string
op_type_
{
"conv2d"
};
std
::
string
quant_type_
;
int
times_
;
};
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/pattern_matcher.cc
浏览文件 @
b950d3a3
...
@@ -115,7 +115,6 @@ void PatternMatcher::operator()(SSAGraph *graph,
...
@@ -115,7 +115,6 @@ void PatternMatcher::operator()(SSAGraph *graph,
bool
PatternMatcher
::
MarkPMNodesInGraph
(
SSAGraph
*
graph
)
{
bool
PatternMatcher
::
MarkPMNodesInGraph
(
SSAGraph
*
graph
)
{
VLOG
(
3
)
<<
"mark pmnodes in graph"
;
VLOG
(
3
)
<<
"mark pmnodes in graph"
;
if
(
graph
->
nodes
().
empty
())
return
false
;
if
(
graph
->
nodes
().
empty
())
return
false
;
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
for
(
const
auto
&
pmnode
:
pattern_
.
nodes
())
{
for
(
const
auto
&
pmnode
:
pattern_
.
nodes
())
{
if
(
pmnode
->
Tell
(
&
node
))
{
if
(
pmnode
->
Tell
(
&
node
))
{
...
@@ -398,7 +397,7 @@ PMNode *PMNode::assert_is_op_output(const std::string &op_type) {
...
@@ -398,7 +397,7 @@ PMNode *PMNode::assert_is_op_output(const std::string &op_type) {
asserts_
.
emplace_back
([
=
](
const
Node
*
x
)
{
asserts_
.
emplace_back
([
=
](
const
Node
*
x
)
{
for
(
auto
*
op
:
x
->
inlinks
)
{
for
(
auto
*
op
:
x
->
inlinks
)
{
if
(
op
&&
op
->
IsStmt
())
{
if
(
op
&&
op
->
IsStmt
())
{
auto
*
op_info
=
x
->
stmt
()
->
op_info
();
auto
*
op_info
=
op
->
stmt
()
->
op_info
();
if
(
op_info
->
Type
()
==
op_type
)
return
true
;
if
(
op_info
->
Type
()
==
op_type
)
return
true
;
}
}
}
}
...
...
paddle/fluid/lite/core/mir/use_passes.h
浏览文件 @
b950d3a3
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
#pragma once
#pragma once
#include "paddle/fluid/lite/core/mir/pass_registry.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
USE_MIR_PASS
(
demo
);
USE_MIR_PASS
(
demo
);
USE_MIR_PASS
(
static_kernel_pick_pass
);
USE_MIR_PASS
(
static_kernel_pick_pass
);
USE_MIR_PASS
(
variable_place_inference_pass
);
USE_MIR_PASS
(
variable_place_inference_pass
);
...
@@ -23,12 +22,12 @@ USE_MIR_PASS(type_target_transform_pass);
...
@@ -23,12 +22,12 @@ USE_MIR_PASS(type_target_transform_pass);
USE_MIR_PASS
(
generate_program_pass
);
USE_MIR_PASS
(
generate_program_pass
);
USE_MIR_PASS
(
io_copy_kernel_pick_pass
);
USE_MIR_PASS
(
io_copy_kernel_pick_pass
);
USE_MIR_PASS
(
argument_type_display_pass
);
USE_MIR_PASS
(
argument_type_display_pass
);
#endif
USE_MIR_PASS
(
runtime_context_assign_pass
);
USE_MIR_PASS
(
runtime_context_assign_pass
);
USE_MIR_PASS
(
lite_conv_bn_fuse_pass
);
USE_MIR_PASS
(
graph_visualze
);
USE_MIR_PASS
(
graph_visualze
);
USE_MIR_PASS
(
lite_conv_bn_fuse_pass
);
USE_MIR_PASS
(
lite_fc_fuse_pass
);
USE_MIR_PASS
(
lite_fc_fuse_pass
);
USE_MIR_PASS
(
identity_scale_eliminate_pass
);
USE_MIR_PASS
(
identity_scale_eliminate_pass
);
USE_MIR_PASS
(
lite_conv_elementwise_add_activation_fuse_pass
);
USE_MIR_PASS
(
lite_conv_elementwise_add_activation_fuse_pass
);
USE_MIR_PASS
(
lite_elementwise_add_activation_fuse_pass
);
USE_MIR_PASS
(
lite_elementwise_add_activation_fuse_pass
);
USE_MIR_PASS
(
lite_quant_dequant_fuse_pass
);
paddle/fluid/lite/core/optimizer.h
浏览文件 @
b950d3a3
...
@@ -50,6 +50,7 @@ class Optimizer {
...
@@ -50,6 +50,7 @@ class Optimizer {
if
(
passes
.
empty
())
{
if
(
passes
.
empty
())
{
RunPasses
(
std
::
vector
<
std
::
string
>
{{
RunPasses
(
std
::
vector
<
std
::
string
>
{{
"lite_quant_dequant_fuse_pass"
,
//
"lite_conv_bn_fuse_pass"
,
//
"lite_conv_bn_fuse_pass"
,
//
"lite_conv_elementwise_add_activation_fuse_pass"
,
//
"lite_conv_elementwise_add_activation_fuse_pass"
,
//
"lite_fc_fuse_pass"
,
//
"lite_fc_fuse_pass"
,
//
...
@@ -57,18 +58,16 @@ class Optimizer {
...
@@ -57,18 +58,16 @@ class Optimizer {
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"lite_elementwise_add_activation_fuse_pass"
,
//
"lite_elementwise_add_activation_fuse_pass"
,
//
#endif
#endif
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"lite_fc_fuse_pass"
,
//
"static_kernel_pick_pass"
,
//
"static_kernel_pick_pass"
,
//
"variable_place_inference_pass"
,
//
"variable_place_inference_pass"
,
//
"argument_type_display_pass"
,
//
"argument_type_display_pass"
,
//
"type_target_transform_pass"
,
//
"type_target_transform_pass"
,
//
"argument_type_display_pass"
,
//
"variable_place_inference_pass"
,
//
"variable_place_inference_pass"
,
//
"argument_type_display_pass"
,
//
"argument_type_display_pass"
,
//
"io_copy_kernel_pick_pass"
,
//
"io_copy_kernel_pick_pass"
,
//
"variable_place_inference_pass"
,
//
"variable_place_inference_pass"
,
//
#endif
"runtime_context_assign_pass"
,
//
"runtime_context_assign_pass"
,
//
}});
}});
}
else
{
}
else
{
RunPasses
(
passes
);
RunPasses
(
passes
);
...
...
paddle/fluid/lite/core/target_wrapper.h
浏览文件 @
b950d3a3
...
@@ -55,8 +55,8 @@ enum class DataLayoutType : int {
...
@@ -55,8 +55,8 @@ enum class DataLayoutType : int {
#define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__
#define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__
static
const
std
::
string
&
TargetToStr
(
TargetType
target
)
{
static
const
std
::
string
&
TargetToStr
(
TargetType
target
)
{
static
const
std
::
string
target2string
[]
=
{
"unk"
,
"host"
,
"x86"
,
"cuda
"
,
static
const
std
::
string
target2string
[]
=
{
"unk"
,
"host"
,
"x86
"
,
"any"
};
"
cuda"
,
"arm"
,
"
any"
};
auto
x
=
static_cast
<
int
>
(
target
);
auto
x
=
static_cast
<
int
>
(
target
);
CHECK_LT
(
x
,
static_cast
<
int
>
(
TARGET
(
NUM
)));
CHECK_LT
(
x
,
static_cast
<
int
>
(
TARGET
(
NUM
)));
return
target2string
[
x
];
return
target2string
[
x
];
...
...
paddle/fluid/lite/core/type_system.h
浏览文件 @
b950d3a3
...
@@ -165,8 +165,8 @@ class Type : public DataType {
...
@@ -165,8 +165,8 @@ class Type : public DataType {
// -------------------------------- compatible check ---------------------------
// -------------------------------- compatible check ---------------------------
static
bool
TargetCompatibleTo
(
const
Type
&
a
,
const
Type
&
b
)
{
static
bool
TargetCompatibleTo
(
const
Type
&
a
,
const
Type
&
b
)
{
auto
is_host
=
[](
TargetType
x
)
{
auto
is_host
=
[](
TargetType
x
)
->
bool
{
return
x
==
TARGET
(
kHost
)
||
x
==
TARGET
(
kX86
);
return
x
==
TARGET
(
kHost
)
||
x
==
TARGET
(
kX86
)
||
x
==
TARGET
(
kARM
)
;
};
};
if
(
a
.
IsVoid
()
||
b
.
IsVoid
())
return
true
;
if
(
a
.
IsVoid
()
||
b
.
IsVoid
())
return
true
;
if
(
a
.
IsTensor
()
||
b
.
IsTensor
())
{
if
(
a
.
IsTensor
()
||
b
.
IsTensor
())
{
...
...
paddle/fluid/lite/kernels/arm/conv_compute.cc
浏览文件 @
b950d3a3
...
@@ -100,7 +100,7 @@ void ConvCompute::Run() {
...
@@ -100,7 +100,7 @@ void ConvCompute::Run() {
REGISTER_LITE_KERNEL
(
conv2d
,
kARM
,
kFloat
,
kNCHW
,
REGISTER_LITE_KERNEL
(
conv2d
,
kARM
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
ConvCompute
,
def
)
paddle
::
lite
::
kernels
::
arm
::
ConvCompute
,
def
)
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
//
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.
BindInput
(
"Bias"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Filter"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Filter"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Output"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Output"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
.
Finalize
();
...
@@ -108,7 +108,7 @@ REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW,
...
@@ -108,7 +108,7 @@ REGISTER_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW,
REGISTER_LITE_KERNEL
(
depthwise_conv2d
,
kARM
,
kFloat
,
kNCHW
,
REGISTER_LITE_KERNEL
(
depthwise_conv2d
,
kARM
,
kFloat
,
kNCHW
,
paddle
::
lite
::
kernels
::
arm
::
ConvCompute
,
def
)
paddle
::
lite
::
kernels
::
arm
::
ConvCompute
,
def
)
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Input"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
//
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.
BindInput
(
"Bias"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Filter"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindInput
(
"Filter"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Output"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
BindOutput
(
"Output"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kARM
))})
.
Finalize
();
.
Finalize
();
paddle/fluid/lite/operators/CMakeLists.txt
浏览文件 @
b950d3a3
...
@@ -23,6 +23,8 @@ cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS})
...
@@ -23,6 +23,8 @@ cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS})
cc_library
(
concat_op_lite SRCS concat_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
concat_op_lite SRCS concat_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
split_op_lite SRCS split_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
split_op_lite SRCS split_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
transpose_op_lite SRCS transpose_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
transpose_op_lite SRCS transpose_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
fake_quant SRCS fake_quantize_moving_avg_max_abs.cc DEPS
${
op_DEPS
}
)
cc_library
(
fake_dequant SRCS fake_dequantize_max_abs.cc DEPS
${
op_DEPS
}
)
set
(
ops_lite
set
(
ops_lite
conv_op_lite
conv_op_lite
...
@@ -46,6 +48,8 @@ set(ops_lite
...
@@ -46,6 +48,8 @@ set(ops_lite
concat_op_lite
concat_op_lite
split_op_lite
split_op_lite
transpose_op_lite
transpose_op_lite
fake_quant
fake_dequant
PARENT_SCOPE
)
PARENT_SCOPE
)
lite_cc_test
(
test_fc_op_lite SRCS fc_op_test.cc
lite_cc_test
(
test_fc_op_lite SRCS fc_op_test.cc
...
...
paddle/fluid/lite/operators/fake_dequantize_max_abs.cc
0 → 100644
浏览文件 @
b950d3a3
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/fake_dequantize_max_abs.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{}
// namespace operators
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_OP
(
fake_dequantize_max_abs
,
paddle
::
lite
::
operators
::
FakeDequantizeMaxAbsOpLite
);
paddle/fluid/lite/operators/fake_dequantize_max_abs.h
0 → 100644
浏览文件 @
b950d3a3
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
class
FakeDequantizeMaxAbsOpLite
:
public
OpLite
{
public:
FakeDequantizeMaxAbsOpLite
()
{}
explicit
FakeDequantizeMaxAbsOpLite
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
{
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
in_scale
=
op_desc
.
Input
(
"Scale"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
param_
.
x
=
scope
->
FindVar
(
x
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
in_scale
=
scope
->
FindVar
(
in_scale
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
out
=
scope
->
FindVar
(
out
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
max_range
=
op_desc
.
GetAttr
<
float
>
(
"max_range"
);
return
true
;
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"fake_dequantize_max_abs"
;
}
private:
mutable
FakeDequantizeMaxAbsParam
param_
;
};
}
// namespace operators
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.cc
0 → 100644
浏览文件 @
b950d3a3
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{}
// namespace operators
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_OP
(
fake_quantize_moving_average_abs_max
,
paddle
::
lite
::
operators
::
FakeQuantizeMovingAvgMaxAbsOpLite
);
paddle/fluid/lite/operators/fake_quantize_moving_avg_max_abs.h
0 → 100644
浏览文件 @
b950d3a3
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
class
FakeQuantizeMovingAvgMaxAbsOpLite
:
public
OpLite
{
public:
FakeQuantizeMovingAvgMaxAbsOpLite
()
{}
explicit
FakeQuantizeMovingAvgMaxAbsOpLite
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
{
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
in_scale
=
op_desc
.
Input
(
"InScale"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
out_scale
=
op_desc
.
Output
(
"OutScale"
).
front
();
param_
.
x
=
scope
->
FindVar
(
x
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
in_scale
=
scope
->
FindVar
(
in_scale
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
out
=
scope
->
FindVar
(
out
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
out_scale
=
scope
->
FindVar
(
out_scale
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
bit_length
=
op_desc
.
GetAttr
<
int
>
(
"bit_length"
);
return
true
;
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"fake_quantize_moving_avg_max_abs"
;
}
private:
mutable
FakeQuantizeMovingAvgMaxAbsParam
param_
;
};
}
// namespace operators
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/operators/op_params.h
浏览文件 @
b950d3a3
...
@@ -273,6 +273,28 @@ struct FillConstantParam {
...
@@ -273,6 +273,28 @@ struct FillConstantParam {
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
//
struct
FakeQuantizeMovingAvgMaxAbsParam
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
in_scale
{};
const
lite
::
Tensor
*
in_accum
{};
const
lite
::
Tensor
*
in_state
{};
lite
::
Tensor
*
out
{};
lite
::
Tensor
*
out_scale
{};
lite
::
Tensor
*
out_state
{};
lite
::
Tensor
*
out_accum
{};
int
bit_length
;
bool
is_test
{
true
};
float
moving_rate
{
0.9
};
};
struct
FakeDequantizeMaxAbsParam
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
in_scale
{};
lite
::
Tensor
*
out
{};
float
max_range
;
};
/// ----------------------- sgd operators ----------------------
/// ----------------------- sgd operators ----------------------
struct
SGDParam
{
struct
SGDParam
{
int
dtype
{
framework
::
proto
::
VarType
::
FP32
};
int
dtype
{
framework
::
proto
::
VarType
::
FP32
};
...
...
paddle/fluid/lite/operators/softmax_op.cc
浏览文件 @
b950d3a3
...
@@ -39,7 +39,12 @@ bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
...
@@ -39,7 +39,12 @@ bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
&
scope
->
FindVar
(
opdesc
.
Input
(
"X"
).
front
())
->
Get
<
lite
::
Tensor
>
());
&
scope
->
FindVar
(
opdesc
.
Input
(
"X"
).
front
())
->
Get
<
lite
::
Tensor
>
());
param_
.
output
=
param_
.
output
=
scope
->
FindVar
(
opdesc
.
Output
(
"Out"
).
front
())
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
FindVar
(
opdesc
.
Output
(
"Out"
).
front
())
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
if
(
opdesc
.
HasAttr
(
"axis"
))
{
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
}
else
{
param_
.
axis
=
-
1
;
}
CHECK
(
param_
.
x
);
CHECK
(
param_
.
x
);
CHECK
(
param_
.
output
);
CHECK
(
param_
.
output
);
return
true
;
return
true
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录