Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
de2a72a4
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
de2a72a4
编写于
6月 17, 2019
作者:
S
sangoly
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add elementwise_add_activation fuse pass & op
上级
09b35192
变更
21
隐藏空白更改
内联
并排
Showing
21 changed file
with
683 addition
and
113 deletion
+683
-113
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+8
-3
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.cc
...ite/core/mir/conv_elementwise_add_activation_fuse_pass.cc
+8
-7
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h
...lite/core/mir/conv_elementwise_add_activation_fuse_pass.h
+32
-0
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass_test.cc
...ore/mir/conv_elementwise_add_activation_fuse_pass_test.cc
+4
-4
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.cc
...uid/lite/core/mir/elementwise_add_activation_fuse_pass.cc
+36
-0
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h
...luid/lite/core/mir/elementwise_add_activation_fuse_pass.h
+1
-1
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass_test.cc
...ite/core/mir/elementwise_add_activation_fuse_pass_test.cc
+117
-0
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
+7
-3
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc
.../core/mir/fusion/conv_elementwise_add_activation_fuser.cc
+13
-9
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h
...e/core/mir/fusion/conv_elementwise_add_activation_fuser.h
+47
-0
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
.../lite/core/mir/fusion/elementwise_add_activation_fuser.cc
+87
-0
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h
...d/lite/core/mir/fusion/elementwise_add_activation_fuser.h
+4
-4
paddle/fluid/lite/core/mir/passes.h
paddle/fluid/lite/core/mir/passes.h
+2
-1
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+6
-3
paddle/fluid/lite/operators/CMakeLists.txt
paddle/fluid/lite/operators/CMakeLists.txt
+5
-0
paddle/fluid/lite/operators/elementwise_ops.cc
paddle/fluid/lite/operators/elementwise_ops.cc
+53
-78
paddle/fluid/lite/operators/elementwise_ops.h
paddle/fluid/lite/operators/elementwise_ops.h
+65
-0
paddle/fluid/lite/operators/fusion_elementwise_activation_ops.cc
...fluid/lite/operators/fusion_elementwise_activation_ops.cc
+57
-0
paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h
.../fluid/lite/operators/fusion_elementwise_activation_ops.h
+60
-0
paddle/fluid/lite/operators/fusion_elementwise_activation_ops_test.cc
.../lite/operators/fusion_elementwise_activation_ops_test.cc
+63
-0
paddle/fluid/lite/operators/op_params.h
paddle/fluid/lite/operators/op_params.h
+8
-0
未找到文件。
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
de2a72a4
...
...
@@ -7,7 +7,8 @@ cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
add_subdirectory
(
fusion
)
cc_library
(
mir_passes
SRCS fc_fuse_pass.cc
conv_elementwise_add_relu_fuse_pass.cc
conv_elementwise_add_activation_fuse_pass.cc
elementwise_add_activation_fuse_pass.cc
conv_bn_fuse_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
...
...
@@ -82,7 +83,11 @@ lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz
add_dependencies
(
test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz
)
lite_cc_test
(
test_lite_conv_elementwise_add_relu_fuse
SRCS conv_elementwise_add_relu_fuse_pass_test.cc
lite_cc_test
(
test_lite_conv_elementwise_add_activation_fuse
SRCS conv_elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
lite_cc_test
(
test_lite_elementwise_add_activation_fuse
SRCS elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
paddle/fluid/lite/core/mir/conv_elementwise_add_
relu
_fuse_pass.cc
→
paddle/fluid/lite/core/mir/conv_elementwise_add_
activation
_fuse_pass.cc
浏览文件 @
de2a72a4
...
...
@@ -12,22 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_
relu
_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_
activation
_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_
relu
_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_
activation
_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
ConvElementwiseAdd
ReLU
FusePass
::
Apply
(
void
ConvElementwiseAdd
Activation
FusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
ConvElementwiseAdd
ReLUFuser
fuser
(
"conv2d
"
);
fusion
::
ConvElementwiseAdd
ActivationFuser
fuser
(
"conv2d"
,
"relu
"
);
fuser
(
graph
.
get
());
fusion
::
ConvElementwiseAddReLUFuser
depthwise_fuser
(
"depthwise_conv2d"
);
fusion
::
ConvElementwiseAddActivationFuser
depthwise_fuser
(
"depthwise_conv2d"
,
"relu"
);
depthwise_fuser
(
graph
.
get
());
}
...
...
@@ -35,5 +36,5 @@ void ConvElementwiseAddReLUFusePass::Apply(
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_conv_elementwise_add_act_fuse_pass
,
paddle
::
lite
::
mir
::
ConvElementwiseAdd
ReLU
FusePass
);
REGISTER_MIR_PASS
(
lite_conv_elementwise_add_act
ivation
_fuse_pass
,
paddle
::
lite
::
mir
::
ConvElementwiseAdd
Activation
FusePass
);
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
ConvElementwiseAddActivationFusePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/conv_elementwise_add_
relu
_fuse_pass_test.cc
→
paddle/fluid/lite/core/mir/conv_elementwise_add_
activation
_fuse_pass_test.cc
浏览文件 @
de2a72a4
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_
relu
_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_
activation
_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
...
...
@@ -135,11 +135,11 @@ TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) {
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
Visualize
(
graph
.
get
());
const
int
num_nodes
=
graph
->
nodes
().
size
();
auto
*
fuser
=
new
ConvElementwiseAdd
ReLU
FusePass
;
auto
*
fuser
=
new
ConvElementwiseAdd
Activation
FusePass
;
fuser
->
Apply
(
graph
);
Visualize
(
graph
.
get
());
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
5UL
*
2
/*nodes removed */
+
1UL
*
2
/* fused fc node
*/
);
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
5UL
*
2
/*nodes removed */
+
1UL
*
2
/* fused nodes
*/
);
}
}
// namespace fusion
...
...
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.cc
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
ElementwiseAddActivationFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
ElementwiseAddActivationFuser
fuser
(
"relu"
);
fuser
(
graph
.
get
());
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_elementwise_add_activation_fuse_pass
,
paddle
::
lite
::
mir
::
ElementwiseAddActivationFusePass
);
paddle/fluid/lite/core/mir/
conv_elementwise_add_relu
_fuse_pass.h
→
paddle/fluid/lite/core/mir/
elementwise_add_activation
_fuse_pass.h
浏览文件 @
de2a72a4
...
...
@@ -22,7 +22,7 @@ namespace paddle {
namespace
lite
{
namespace
mir
{
class
ConvElementwiseAddReLU
FusePass
:
public
ProgramPass
{
class
ElementwiseAddActivation
FusePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
;
};
...
...
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass_test.cc
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
std
::
unique_ptr
<
SSAGraph
>
BuildGraph
(
framework
::
ProgramDesc
*
program_desc
,
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
std
::
vector
<
Place
>&
valid_places
)
{
auto
*
main_block
=
program_desc
->
MutableBlock
(
0
);
auto
*
add_1
=
main_block
->
AppendOp
();
auto
*
add_2
=
main_block
->
AppendOp
();
auto
*
relu_1
=
main_block
->
AppendOp
();
auto
*
relu_2
=
main_block
->
AppendOp
();
main_block
->
Var
(
"x_1"
);
main_block
->
Var
(
"y_1"
);
main_block
->
Var
(
"add_out_1"
);
main_block
->
Var
(
"relu_out_1"
);
main_block
->
Var
(
"y_2"
);
main_block
->
Var
(
"add_out_2"
);
main_block
->
Var
(
"out"
);
scope
->
Var
(
"x_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"y_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"add_out_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"relu_out_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"y_2"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"add_out_2"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"out"
)
->
GetMutable
<
lite
::
Tensor
>
();
add_1
->
SetType
(
"elementwise_add"
);
add_1
->
SetInput
(
"X"
,
{
"x_1"
});
add_1
->
SetInput
(
"Y"
,
{
"y_1"
});
add_1
->
SetOutput
(
"Out"
,
{
"add_out_1"
});
add_1
->
SetAttr
(
"axis"
,
1
);
relu_1
->
SetType
(
"relu"
);
relu_1
->
SetInput
(
"X"
,
{
"add_out_1"
});
relu_1
->
SetOutput
(
"Out"
,
{
"relu_out_1"
});
add_2
->
SetType
(
"elementwise_add"
);
add_2
->
SetInput
(
"X"
,
{
"relu_out_1"
});
add_2
->
SetInput
(
"Y"
,
{
"y_2"
});
add_2
->
SetOutput
(
"Out"
,
{
"add_out_2"
});
add_2
->
SetAttr
(
"axis"
,
1
);
relu_2
->
SetType
(
"relu"
);
relu_2
->
SetInput
(
"X"
,
{
"add_out_2"
});
relu_2
->
SetOutput
(
"Out"
,
{
"out"
});
program_desc
->
Flush
();
lite
::
Program
program
(
*
program_desc
->
Proto
(),
scope
,
valid_places
);
auto
graph
=
std
::
unique_ptr
<
SSAGraph
>
(
new
SSAGraph
());
graph
->
Build
(
program
,
valid_places
);
return
graph
;
}
TEST
(
elementwise_add_activation_fuse_pass
,
graph_test
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
ASSERT_EQ
(
graph
->
nodes
().
size
(),
7UL
/*vars*/
+
4UL
/*ops*/
+
1UL
/* SSAGraph tmp node*/
);
}
TEST
(
elementwise_add_activation_fuse_pass
,
fuse_test_op
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
Visualize
(
graph
.
get
());
const
int
num_nodes
=
graph
->
nodes
().
size
();
auto
*
fuser
=
new
ElementwiseAddActivationFusePass
;
fuser
->
Apply
(
graph
);
Visualize
(
graph
.
get
());
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
3UL
*
2
/*nodes removed */
+
1UL
*
2
/* fused nodes*/
);
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
USE_LITE_OP
(
elementwise_add
);
USE_LITE_OP
(
fusion_elementwise_add_activation
);
USE_LITE_OP
(
relu
);
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
浏览文件 @
de2a72a4
cc_library
(
fuse_fc
SRCS fc_fuser.cc
DEPS pattern_matcher_high_api
)
cc_library
(
fuse_conv_elementwise_add_
relu
SRCS conv_elementwise_add_
relu
_fuser.cc
cc_library
(
fuse_conv_elementwise_add_
activation
SRCS conv_elementwise_add_
activation
_fuser.cc
DEPS pattern_matcher_high_api
)
cc_library
(
fuse_conv_bn
SRCS conv_bn_fuser.cc
DEPS pattern_matcher_high_api
)
cc_library
(
fuse_elementwise_add_activation
SRCS elementwise_add_activation_fuser.cc
DEPS pattern_matcher_high_api
)
set
(
mir_fusers
fuse_fc
fuse_conv_elementwise_add_
relu
fuse_conv_elementwise_add_
activation
fuse_conv_bn
fuse_elementwise_add_activation
CACHE INTERNAL
"fusers"
)
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
...
...
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_
relu
_fuser.cc
→
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_
activation
_fuser.cc
浏览文件 @
de2a72a4
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_
relu
_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_
activation
_fuser.h"
#include <memory>
#include <vector>
...
...
@@ -21,7 +21,7 @@ namespace lite {
namespace
mir
{
namespace
fusion
{
void
ConvElementwiseAdd
ReLU
Fuser
::
BuildPattern
()
{
void
ConvElementwiseAdd
Activation
Fuser
::
BuildPattern
()
{
// create input nodes.
auto
*
input
=
VarNode
(
"input"
)
->
assert_is_op_input
(
conv_type_
,
"Input"
)
->
AsInput
();
...
...
@@ -36,7 +36,8 @@ void ConvElementwiseAddReLUFuser::BuildPattern() {
auto
*
add
=
OpNode
(
"add"
,
"elementwise_add"
)
->
assert_is_op
(
"elementwise_add"
)
->
AsIntermediate
();
auto
*
relu
=
OpNode
(
"relu"
,
"relu"
)
->
assert_is_op
(
"relu"
)
->
AsIntermediate
();
auto
*
act
=
OpNode
(
"act"
,
act_type_
)
->
assert_is_op
(
act_type_
)
->
AsIntermediate
();
// create intermediate nodes
auto
*
conv2d_out
=
VarNode
(
"conv2d_out"
)
...
...
@@ -45,22 +46,23 @@ void ConvElementwiseAddReLUFuser::BuildPattern() {
->
AsIntermediate
();
auto
*
add_out
=
VarNode
(
"add_out"
)
->
assert_is_op_output
(
"elementwise_add"
,
"Out"
)
->
assert_is_op_input
(
"relu"
,
"X"
)
->
assert_is_op_input
(
act_type_
,
"X"
)
->
AsIntermediate
();
// create output node
auto
*
out
=
VarNode
(
"output"
)
->
assert_is_op_output
(
"relu"
,
"Out"
)
->
AsOutput
();
auto
*
out
=
VarNode
(
"output"
)
->
assert_is_op_output
(
act_type_
,
"Out"
)
->
AsOutput
();
// create topology.
std
::
vector
<
PMNode
*>
conv2d_inputs
{
filter
,
input
};
std
::
vector
<
PMNode
*>
add_inputs
{
conv2d_out
,
bias
};
conv2d_inputs
>>
*
conv2d
>>
*
conv2d_out
;
add_inputs
>>
*
add
>>
*
add_out
;
*
add_out
>>
*
relu
>>
*
out
;
*
add_out
>>
*
act
>>
*
out
;
}
void
ConvElementwiseAdd
ReLUFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
void
ConvElementwiseAdd
ActivationFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
conv_op
=
LiteOpRegistry
::
Global
().
Create
(
conv_type_
);
auto
conv_old
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op
;
...
...
@@ -76,7 +78,8 @@ void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
IR_NODE_LINK_TO
(
new_op_node
,
matched
.
at
(
"output"
));
}
cpp
::
OpDesc
ConvElementwiseAddReLUFuser
::
GenOpDesc
(
const
key2nodes_t
&
matched
)
{
cpp
::
OpDesc
ConvElementwiseAddActivationFuser
::
GenOpDesc
(
const
key2nodes_t
&
matched
)
{
auto
*
desc
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op_info
();
cpp
::
OpDesc
op_desc
;
...
...
@@ -98,6 +101,7 @@ cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
op_desc
.
SetAttr
(
"paddings"
,
desc
->
GetAttr
<
std
::
vector
<
int
>>
(
"paddings"
));
op_desc
.
SetAttr
(
"groups"
,
desc
->
GetAttr
<
int
>
(
"groups"
));
op_desc
.
SetAttr
(
"dilations"
,
desc
->
GetAttr
<
std
::
vector
<
int
>>
(
"dilations"
));
// TODO(sangoly): support other activation types
op_desc
.
SetAttr
(
"fuse_relu"
,
true
);
return
op_desc
;
}
...
...
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
class
ConvElementwiseAddActivationFuser
:
public
FuseBase
{
public:
explicit
ConvElementwiseAddActivationFuser
(
const
std
::
string
&
conv_type
,
const
std
::
string
&
act_type
)
{
CHECK
(
act_type
==
"relu"
)
<<
"Only relu activation be supported now"
;
conv_type_
=
conv_type
;
act_type_
=
act_type
;
}
void
BuildPattern
()
override
;
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
;
private:
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
override
;
std
::
string
conv_type_
;
std
::
string
act_type_
;
};
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
#include <memory>
#include <vector>
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
void
ElementwiseAddActivationFuser
::
BuildPattern
()
{
// create input nodes.
auto
*
x
=
VarNode
(
"x"
)
->
assert_is_op_input
(
"elementwise_add"
,
"X"
)
->
AsInput
();
auto
*
y
=
VarNode
(
"y"
)
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
)
->
AsInput
();
// create op nodes
auto
*
add
=
OpNode
(
"add"
,
"elementwise_add"
)
->
assert_is_op
(
"elementwise_add"
)
->
AsIntermediate
();
auto
*
act
=
OpNode
(
"act"
,
act_type_
)
->
assert_is_op
(
act_type_
)
->
AsIntermediate
();
// create intermediate nodes
auto
*
add_out
=
VarNode
(
"add_out"
)
->
assert_is_op_output
(
"elementwise_add"
,
"Out"
)
->
assert_is_op_input
(
act_type_
,
"X"
)
->
AsIntermediate
();
// create output node
auto
*
out
=
VarNode
(
"output"
)
->
assert_is_op_output
(
act_type_
,
"Out"
)
->
AsOutput
();
// create topology.
std
::
vector
<
PMNode
*>
add_inputs
{
x
,
y
};
add_inputs
>>
*
add
>>
*
add_out
;
*
add_out
>>
*
act
>>
*
out
;
}
void
ElementwiseAddActivationFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
op
=
LiteOpRegistry
::
Global
().
Create
(
"fusion_elementwise_add_activation"
);
auto
old_op
=
matched
.
at
(
"add"
)
->
stmt
()
->
op
;
auto
*
scope
=
old_op
->
scope
();
auto
&
valid_places
=
old_op
->
valid_places
();
op
->
Attach
(
op_desc
,
scope
);
auto
*
new_op_node
=
graph
->
GraphCreateInstructNode
(
op
,
valid_places
);
IR_NODE_LINK_TO
(
matched
.
at
(
"x"
),
new_op_node
);
IR_NODE_LINK_TO
(
matched
.
at
(
"y"
),
new_op_node
);
IR_NODE_LINK_TO
(
new_op_node
,
matched
.
at
(
"output"
));
}
cpp
::
OpDesc
ElementwiseAddActivationFuser
::
GenOpDesc
(
const
key2nodes_t
&
matched
)
{
auto
*
desc
=
matched
.
at
(
"add"
)
->
stmt
()
->
op_info
();
cpp
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"fusion_elementwise_add_activation"
);
op_desc
.
SetInput
(
"X"
,
{
matched
.
at
(
"x"
)
->
arg
()
->
name
});
op_desc
.
SetInput
(
"Y"
,
{
matched
.
at
(
"y"
)
->
arg
()
->
name
});
op_desc
.
SetOutput
(
"Out"
,
{
matched
.
at
(
"output"
)
->
arg
()
->
name
});
op_desc
.
SetAttr
(
"axis"
,
desc
->
GetAttr
<
int
>
(
"axis"
));
op_desc
.
SetAttr
(
"act_type"
,
act_type_
);
return
op_desc
;
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/
conv_elementwise_add_relu
_fuser.h
→
paddle/fluid/lite/core/mir/fusion/
elementwise_add_activation
_fuser.h
浏览文件 @
de2a72a4
...
...
@@ -23,16 +23,16 @@ namespace lite {
namespace
mir
{
namespace
fusion
{
class
ConvElementwiseAddReLU
Fuser
:
public
FuseBase
{
class
ElementwiseAddActivation
Fuser
:
public
FuseBase
{
public:
explicit
ConvElementwiseAddReLUFuser
(
const
std
::
string
&
conv
_type
)
:
conv_type_
(
conv
_type
)
{}
explicit
ElementwiseAddActivationFuser
(
const
std
::
string
&
act
_type
)
:
act_type_
(
act
_type
)
{}
void
BuildPattern
()
override
;
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
;
private:
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
override
;
std
::
string
conv
_type_
;
std
::
string
act
_type_
;
};
}
// namespace fusion
...
...
paddle/fluid/lite/core/mir/passes.h
浏览文件 @
de2a72a4
...
...
@@ -34,4 +34,5 @@ USE_MIR_PASS(runtime_context_assign_pass);
USE_MIR_PASS
(
lite_conv_bn_fuse_pass
);
USE_MIR_PASS
(
graph_visualze
);
USE_MIR_PASS
(
lite_fc_fuse_pass
);
USE_MIR_PASS
(
lite_conv_elementwise_add_act_fuse_pass
);
USE_MIR_PASS
(
lite_conv_elementwise_add_activation_fuse_pass
);
USE_MIR_PASS
(
lite_elementwise_add_activation_fuse_pass
);
paddle/fluid/lite/core/optimizer.h
浏览文件 @
de2a72a4
...
...
@@ -48,9 +48,12 @@ class Optimizer {
if
(
passes
.
empty
())
{
RunPasses
(
std
::
vector
<
std
::
string
>
{{
"lite_conv_bn_fuse_pass"
,
//
"lite_conv_elementwise_add_act_fuse_pass"
,
//
"lite_fc_fuse_pass"
,
//
"lite_conv_bn_fuse_pass"
,
//
"lite_conv_elementwise_add_activation_fuse_pass"
,
//
"lite_fc_fuse_pass"
,
//
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"lite_elementwise_add_activation_fuse_pass"
,
//
#endif
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"static_kernel_pick_pass"
,
//
"variable_place_inference_pass"
,
//
...
...
paddle/fluid/lite/operators/CMakeLists.txt
浏览文件 @
de2a72a4
...
...
@@ -14,6 +14,7 @@ cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
cc_library
(
io_copy_op_lite SRCS io_copy_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
activation_ops_lite SRCS activation_ops.cc DEPS
${
op_DEPS
}
)
cc_library
(
elementwise_ops_lite SRCS elementwise_ops.cc DEPS
${
op_DEPS
}
)
cc_library
(
fusion_elementwise_activation_ops_lite SRCS fusion_elementwise_activation_ops.cc DEPS elementwise_ops_lite
${
op_DEPS
}
)
cc_library
(
mean_op_lite SRCS mean_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
fill_constant_op_lite SRCS fill_constant_op.cc DEPS
${
op_DEPS
}
)
#cc_library(sgd_op_lite SRCS sgd_op.cc DEPS ${op_DEPS})
...
...
@@ -36,6 +37,7 @@ set(ops_lite
fetch_op_lite
io_copy_op_lite
elementwise_ops_lite
fusion_elementwise_activation_ops_lite
mean_op_lite
fill_constant_op_lite
activation_ops_lite
...
...
@@ -56,3 +58,6 @@ lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite m
lite_cc_test
(
test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite
)
lite_cc_test
(
test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite
)
lite_cc_test
(
test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite
)
lite_cc_test
(
test_fusion_elementwise_activation_ops_lite
SRCS fusion_elementwise_activation_ops_test.cc
DEPS fusion_elementwise_activation_ops_lite memory_lite
)
paddle/fluid/lite/operators/elementwise_ops.cc
浏览文件 @
de2a72a4
...
...
@@ -12,92 +12,67 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/
core/op_lite
.h"
#include "paddle/fluid/lite/
operators/elementwise_ops
.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
class
ElementwiseOp
:
public
OpLite
{
public:
explicit
ElementwiseOp
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
{
CHECK_OR_FALSE
(
param_
.
X
);
CHECK_OR_FALSE
(
param_
.
Y
);
CHECK_OR_FALSE
(
param_
.
Out
);
return
true
;
}
bool
InferShape
()
const
override
{
CHECK_OR_FALSE
(
param_
.
X
->
dims
().
size
()
>=
param_
.
Y
->
dims
().
size
());
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
return
true
;
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Y_name
=
opdesc
.
Input
(
"Y"
).
front
();
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
param_
.
X
=
GetVar
<
lite
::
Tensor
>
(
scope
,
X_name
);
param_
.
Y
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Y_name
);
param_
.
Out
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"elementwise_op"
;
}
private:
mutable
operators
::
ElementwiseParam
param_
;
};
bool
ElementwiseOp
::
CheckShape
()
const
{
CHECK_OR_FALSE
(
param_
.
X
);
CHECK_OR_FALSE
(
param_
.
Y
);
CHECK_OR_FALSE
(
param_
.
Out
);
return
true
;
}
bool
ElementwiseOp
::
InferShape
()
const
{
CHECK_OR_FALSE
(
param_
.
X
->
dims
().
size
()
>=
param_
.
Y
->
dims
().
size
());
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
return
true
;
}
bool
ElementwiseOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Y_name
=
opdesc
.
Input
(
"Y"
).
front
();
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
param_
.
X
=
GetVar
<
lite
::
Tensor
>
(
scope
,
X_name
);
param_
.
Y
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Y_name
);
param_
.
Out
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
}
#ifdef LITE_WITH_X86
class
ElementwiseGradExplicitOp
:
public
OpLite
{
public:
explicit
ElementwiseGradExplicitOp
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
{
CHECK_OR_FALSE
(
param_
.
Y
);
CHECK_OR_FALSE
(
param_
.
X_grad
);
CHECK_OR_FALSE
(
param_
.
Y_grad
);
CHECK_OR_FALSE
(
param_
.
Out_grad
);
return
true
;
}
bool
InferShape
()
const
override
{
param_
.
X_grad
->
Resize
(
param_
.
Out_grad
->
dims
());
param_
.
Y_grad
->
Resize
(
param_
.
Y
->
dims
());
return
true
;
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
CHECK_EQ
(
opdesc
.
InputArgumentNames
().
size
(),
1UL
);
auto
Out_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
X_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
auto
Y_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"Y"
)).
front
();
param_
.
Out_grad
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
param_
.
X_grad
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
X_name
);
param_
.
Y_grad
=
GetMutableVar
<
Tensor
>
(
scope
,
Y_name
);
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"elementwise_grad_explicit_op"
;
}
private:
mutable
operators
::
ElementwiseGradParam
param_
;
};
bool
ElementwiseGradExplicitOp
::
CheckShape
()
const
{
CHECK_OR_FALSE
(
param_
.
Y
);
CHECK_OR_FALSE
(
param_
.
X_grad
);
CHECK_OR_FALSE
(
param_
.
Y_grad
);
CHECK_OR_FALSE
(
param_
.
Out_grad
);
return
true
;
}
bool
ElementwiseGradExplicitOp
::
InferShape
()
const
{
param_
.
X_grad
->
Resize
(
param_
.
Out_grad
->
dims
());
param_
.
Y_grad
->
Resize
(
param_
.
Y
->
dims
());
return
true
;
}
bool
ElementwiseGradExplicitOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
CHECK_EQ
(
opdesc
.
InputArgumentNames
().
size
(),
1UL
);
auto
Out_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
X_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
auto
Y_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"Y"
)).
front
();
param_
.
Out_grad
=
GetVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
param_
.
X_grad
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
X_name
);
param_
.
Y_grad
=
GetMutableVar
<
Tensor
>
(
scope
,
Y_name
);
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
}
#endif
}
// namespace operators
...
...
paddle/fluid/lite/operators/elementwise_ops.h
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
class
ElementwiseOp
:
public
OpLite
{
public:
explicit
ElementwiseOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"elementwise_op"
;
}
private:
mutable
operators
::
ElementwiseParam
param_
;
};
#ifdef LITE_WITH_X86
class
ElementwiseGradExplicitOp
:
public
OpLite
{
public:
explicit
ElementwiseGradExplicitOp
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"elementwise_grad_explicit_op"
;
}
private:
mutable
operators
::
ElementwiseGradParam
param_
;
};
#endif
}
// namespace operators
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/operators/fusion_elementwise_activation_ops.cc
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h"
#include <string>
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
bool
FusionElementwiseActivationOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
ElementwiseOp
::
AttachImpl
(
opdesc
,
scope
);
param_
.
act_type
=
opdesc
.
GetAttr
<
std
::
string
>
(
"act_type"
);
// TODO(sangoly): support more activation types.
CHECK
(
param_
.
act_type
==
"relu"
)
<<
"Only relu activation be supported now"
;
return
true
;
}
#ifdef LITE_WITH_X86
bool
FusionElementwiseActivationGradExplicitOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
ElementwiseGradExplicitOp
::
AttachImpl
(
opdesc
,
scope
);
param_
.
act_type
=
opdesc
.
GetAttr
<
std
::
string
>
(
"act_type"
);
// TODO(sangoly): support more activation types.
CHECK
(
param_
.
act_type
==
"relu"
)
<<
"Only relu activation be supported now"
;
return
true
;
}
#endif
}
// namespace operators
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_OP
(
fusion_elementwise_sub_activation
,
paddle
::
lite
::
operators
::
FusionElementwiseActivationOp
);
#ifdef LITE_WITH_X86
REGISTER_LITE_OP
(
fusion_elementwise_sub_activation_grad
,
paddle
::
lite
::
operators
::
FusionElementwiseActivationGradExplicitOp
);
#endif
REGISTER_LITE_OP
(
fusion_elementwise_add_activation
,
paddle
::
lite
::
operators
::
FusionElementwiseActivationOp
);
paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/operators/elementwise_ops.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
class
FusionElementwiseActivationOp
:
public
ElementwiseOp
{
public:
explicit
FusionElementwiseActivationOp
(
const
std
::
string
&
type
)
:
ElementwiseOp
(
type
)
{}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
std
::
string
DebugString
()
const
override
{
return
"fusion_elementwise_activation_op"
;
}
private:
mutable
operators
::
FusionElementwiseActivationParam
param_
;
};
#ifdef LITE_WITH_X86
class
FusionElementwiseActivationGradExplicitOp
:
public
ElementwiseGradExplicitOp
{
public:
explicit
FusionElementwiseActivationGradExplicitOp
(
const
std
::
string
&
type
)
:
ElementwiseGradExplicitOp
(
type
)
{}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
std
::
string
DebugString
()
const
override
{
return
"fusion_elementwise_activation_grad_explicit_op"
;
}
private:
mutable
operators
::
FusionElementwiseActivationGradParam
param_
;
};
#endif
}
// namespace operators
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/operators/fusion_elementwise_activation_ops_test.cc
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h"
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
TEST
(
fusion_elementwise_activation_op_lite
,
test
)
{
// prepare variables
lite
::
Scope
scope
;
auto
*
x
=
scope
.
Var
(
"x"
)
->
GetMutable
<
lite
::
Tensor
>
();
auto
*
y
=
scope
.
Var
(
"y"
)
->
GetMutable
<
lite
::
Tensor
>
();
auto
*
out
=
scope
.
Var
(
"out"
)
->
GetMutable
<
lite
::
Tensor
>
();
x
->
Resize
(
lite
::
DDim
(
std
::
vector
<
int64_t
>
({
10
,
20
})));
y
->
Resize
(
lite
::
DDim
(
std
::
vector
<
int64_t
>
({
10
,
20
})));
out
->
Resize
(
lite
::
DDim
(
std
::
vector
<
int64_t
>
{
10
,
20
}));
// set data
for
(
int
i
=
0
;
i
<
10
*
20
;
i
++
)
{
x
->
mutable_data
<
float
>
()[
i
]
=
i
;
}
for
(
int
i
=
0
;
i
<
10
*
20
;
i
++
)
{
y
->
mutable_data
<
float
>
()[
i
]
=
i
;
}
for
(
int
i
=
0
;
i
<
10
*
20
;
i
++
)
{
out
->
mutable_data
<
float
>
()[
i
]
=
0.
;
}
// prepare op desc
cpp
::
OpDesc
desc
;
desc
.
SetType
(
"fusion_elementwise_add_activation"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
desc
.
SetInput
(
"Y"
,
{
"y"
});
desc
.
SetOutput
(
"Out"
,
{
"out"
});
desc
.
SetAttr
(
"axis"
,
static_cast
<
int
>
(
1
));
desc
.
SetAttr
(
"act_type"
,
std
::
string
(
"relu"
));
FusionElementwiseActivationOp
fuse_op
(
"fusion_elementwise_add_activation"
);
fuse_op
.
SetValidPlaces
({
Place
{
TARGET
(
kX86
),
PRECISION
(
kFloat
)}});
fuse_op
.
Attach
(
desc
,
&
scope
);
}
}
// namespace operators
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/operators/op_params.h
浏览文件 @
de2a72a4
...
...
@@ -219,6 +219,14 @@ struct ElementwiseGradParam {
int
axis
{
-
1
};
// for broadcasting.
};
struct
FusionElementwiseActivationParam
:
public
ElementwiseParam
{
std
::
string
act_type
;
};
struct
FusionElementwiseActivationGradParam
:
public
ElementwiseGradParam
{
std
::
string
act_type
;
};
/// ----------------------- activation operators ----------------------
struct
ActivationParam
{
const
lite
::
Tensor
*
X
{};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录