Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
de2a72a4
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
de2a72a4
编写于
6月 17, 2019
作者:
S
sangoly
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add elementwise_add_activation fuse pass & op
上级
09b35192
变更
21
显示空白变更内容
内联
并排
Showing
21 changed file
with
683 addition
and
113 deletion
+683
-113
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+8
-3
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.cc
...ite/core/mir/conv_elementwise_add_activation_fuse_pass.cc
+8
-7
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h
...lite/core/mir/conv_elementwise_add_activation_fuse_pass.h
+32
-0
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass_test.cc
...ore/mir/conv_elementwise_add_activation_fuse_pass_test.cc
+4
-4
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.cc
...uid/lite/core/mir/elementwise_add_activation_fuse_pass.cc
+36
-0
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h
...luid/lite/core/mir/elementwise_add_activation_fuse_pass.h
+1
-1
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass_test.cc
...ite/core/mir/elementwise_add_activation_fuse_pass_test.cc
+117
-0
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
+7
-3
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.cc
.../core/mir/fusion/conv_elementwise_add_activation_fuser.cc
+13
-9
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h
...e/core/mir/fusion/conv_elementwise_add_activation_fuser.h
+47
-0
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
.../lite/core/mir/fusion/elementwise_add_activation_fuser.cc
+87
-0
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h
...d/lite/core/mir/fusion/elementwise_add_activation_fuser.h
+4
-4
paddle/fluid/lite/core/mir/passes.h
paddle/fluid/lite/core/mir/passes.h
+2
-1
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+6
-3
paddle/fluid/lite/operators/CMakeLists.txt
paddle/fluid/lite/operators/CMakeLists.txt
+5
-0
paddle/fluid/lite/operators/elementwise_ops.cc
paddle/fluid/lite/operators/elementwise_ops.cc
+53
-78
paddle/fluid/lite/operators/elementwise_ops.h
paddle/fluid/lite/operators/elementwise_ops.h
+65
-0
paddle/fluid/lite/operators/fusion_elementwise_activation_ops.cc
...fluid/lite/operators/fusion_elementwise_activation_ops.cc
+57
-0
paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h
.../fluid/lite/operators/fusion_elementwise_activation_ops.h
+60
-0
paddle/fluid/lite/operators/fusion_elementwise_activation_ops_test.cc
.../lite/operators/fusion_elementwise_activation_ops_test.cc
+63
-0
paddle/fluid/lite/operators/op_params.h
paddle/fluid/lite/operators/op_params.h
+8
-0
未找到文件。
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
de2a72a4
...
...
@@ -7,7 +7,8 @@ cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
add_subdirectory
(
fusion
)
cc_library
(
mir_passes
SRCS fc_fuse_pass.cc
conv_elementwise_add_relu_fuse_pass.cc
conv_elementwise_add_activation_fuse_pass.cc
elementwise_add_activation_fuse_pass.cc
conv_bn_fuse_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
...
...
@@ -82,7 +83,11 @@ lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz
add_dependencies
(
test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz
)
lite_cc_test
(
test_lite_conv_elementwise_add_relu_fuse
SRCS conv_elementwise_add_relu_fuse_pass_test.cc
lite_cc_test
(
test_lite_conv_elementwise_add_activation_fuse
SRCS conv_elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
lite_cc_test
(
test_lite_elementwise_add_activation_fuse
SRCS elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
paddle/fluid/lite/core/mir/conv_elementwise_add_
relu
_fuse_pass.cc
→
paddle/fluid/lite/core/mir/conv_elementwise_add_
activation
_fuse_pass.cc
浏览文件 @
de2a72a4
...
...
@@ -12,22 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_
relu
_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_
activation
_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_
relu
_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_
activation
_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
ConvElementwiseAdd
ReLU
FusePass
::
Apply
(
void
ConvElementwiseAdd
Activation
FusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
ConvElementwiseAdd
ReLUFuser
fuser
(
"conv2d
"
);
fusion
::
ConvElementwiseAdd
ActivationFuser
fuser
(
"conv2d"
,
"relu
"
);
fuser
(
graph
.
get
());
fusion
::
ConvElementwiseAddReLUFuser
depthwise_fuser
(
"depthwise_conv2d"
);
fusion
::
ConvElementwiseAddActivationFuser
depthwise_fuser
(
"depthwise_conv2d"
,
"relu"
);
depthwise_fuser
(
graph
.
get
());
}
...
...
@@ -35,5 +36,5 @@ void ConvElementwiseAddReLUFusePass::Apply(
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_conv_elementwise_add_act_fuse_pass
,
paddle
::
lite
::
mir
::
ConvElementwiseAdd
ReLU
FusePass
);
REGISTER_MIR_PASS
(
lite_conv_elementwise_add_act
ivation
_fuse_pass
,
paddle
::
lite
::
mir
::
ConvElementwiseAdd
Activation
FusePass
);
paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
ConvElementwiseAddActivationFusePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/conv_elementwise_add_
relu
_fuse_pass_test.cc
→
paddle/fluid/lite/core/mir/conv_elementwise_add_
activation
_fuse_pass_test.cc
浏览文件 @
de2a72a4
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_
relu
_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_
activation
_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
...
...
@@ -135,11 +135,11 @@ TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) {
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
Visualize
(
graph
.
get
());
const
int
num_nodes
=
graph
->
nodes
().
size
();
auto
*
fuser
=
new
ConvElementwiseAdd
ReLU
FusePass
;
auto
*
fuser
=
new
ConvElementwiseAdd
Activation
FusePass
;
fuser
->
Apply
(
graph
);
Visualize
(
graph
.
get
());
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
5UL
*
2
/*nodes removed */
+
1UL
*
2
/* fused fc node
*/
);
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
5UL
*
2
/*nodes removed */
+
1UL
*
2
/* fused nodes
*/
);
}
}
// namespace fusion
...
...
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.cc
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
ElementwiseAddActivationFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
ElementwiseAddActivationFuser
fuser
(
"relu"
);
fuser
(
graph
.
get
());
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_elementwise_add_activation_fuse_pass
,
paddle
::
lite
::
mir
::
ElementwiseAddActivationFusePass
);
paddle/fluid/lite/core/mir/
conv_elementwise_add_relu
_fuse_pass.h
→
paddle/fluid/lite/core/mir/
elementwise_add_activation
_fuse_pass.h
浏览文件 @
de2a72a4
...
...
@@ -22,7 +22,7 @@ namespace paddle {
namespace
lite
{
namespace
mir
{
class
ConvElementwiseAddReLU
FusePass
:
public
ProgramPass
{
class
ElementwiseAddActivation
FusePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
;
};
...
...
paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass_test.cc
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
std
::
unique_ptr
<
SSAGraph
>
BuildGraph
(
framework
::
ProgramDesc
*
program_desc
,
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
std
::
vector
<
Place
>&
valid_places
)
{
auto
*
main_block
=
program_desc
->
MutableBlock
(
0
);
auto
*
add_1
=
main_block
->
AppendOp
();
auto
*
add_2
=
main_block
->
AppendOp
();
auto
*
relu_1
=
main_block
->
AppendOp
();
auto
*
relu_2
=
main_block
->
AppendOp
();
main_block
->
Var
(
"x_1"
);
main_block
->
Var
(
"y_1"
);
main_block
->
Var
(
"add_out_1"
);
main_block
->
Var
(
"relu_out_1"
);
main_block
->
Var
(
"y_2"
);
main_block
->
Var
(
"add_out_2"
);
main_block
->
Var
(
"out"
);
scope
->
Var
(
"x_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"y_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"add_out_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"relu_out_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"y_2"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"add_out_2"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"out"
)
->
GetMutable
<
lite
::
Tensor
>
();
add_1
->
SetType
(
"elementwise_add"
);
add_1
->
SetInput
(
"X"
,
{
"x_1"
});
add_1
->
SetInput
(
"Y"
,
{
"y_1"
});
add_1
->
SetOutput
(
"Out"
,
{
"add_out_1"
});
add_1
->
SetAttr
(
"axis"
,
1
);
relu_1
->
SetType
(
"relu"
);
relu_1
->
SetInput
(
"X"
,
{
"add_out_1"
});
relu_1
->
SetOutput
(
"Out"
,
{
"relu_out_1"
});
add_2
->
SetType
(
"elementwise_add"
);
add_2
->
SetInput
(
"X"
,
{
"relu_out_1"
});
add_2
->
SetInput
(
"Y"
,
{
"y_2"
});
add_2
->
SetOutput
(
"Out"
,
{
"add_out_2"
});
add_2
->
SetAttr
(
"axis"
,
1
);
relu_2
->
SetType
(
"relu"
);
relu_2
->
SetInput
(
"X"
,
{
"add_out_2"
});
relu_2
->
SetOutput
(
"Out"
,
{
"out"
});
program_desc
->
Flush
();
lite
::
Program
program
(
*
program_desc
->
Proto
(),
scope
,
valid_places
);
auto
graph
=
std
::
unique_ptr
<
SSAGraph
>
(
new
SSAGraph
());
graph
->
Build
(
program
,
valid_places
);
return
graph
;
}
TEST
(
elementwise_add_activation_fuse_pass
,
graph_test
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
ASSERT_EQ
(
graph
->
nodes
().
size
(),
7UL
/*vars*/
+
4UL
/*ops*/
+
1UL
/* SSAGraph tmp node*/
);
}
TEST
(
elementwise_add_activation_fuse_pass
,
fuse_test_op
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
Visualize
(
graph
.
get
());
const
int
num_nodes
=
graph
->
nodes
().
size
();
auto
*
fuser
=
new
ElementwiseAddActivationFusePass
;
fuser
->
Apply
(
graph
);
Visualize
(
graph
.
get
());
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
3UL
*
2
/*nodes removed */
+
1UL
*
2
/* fused nodes*/
);
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
USE_LITE_OP
(
elementwise_add
);
USE_LITE_OP
(
fusion_elementwise_add_activation
);
USE_LITE_OP
(
relu
);
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
浏览文件 @
de2a72a4
cc_library
(
fuse_fc
SRCS fc_fuser.cc
DEPS pattern_matcher_high_api
)
cc_library
(
fuse_conv_elementwise_add_
relu
SRCS conv_elementwise_add_
relu
_fuser.cc
cc_library
(
fuse_conv_elementwise_add_
activation
SRCS conv_elementwise_add_
activation
_fuser.cc
DEPS pattern_matcher_high_api
)
cc_library
(
fuse_conv_bn
SRCS conv_bn_fuser.cc
DEPS pattern_matcher_high_api
)
cc_library
(
fuse_elementwise_add_activation
SRCS elementwise_add_activation_fuser.cc
DEPS pattern_matcher_high_api
)
set
(
mir_fusers
fuse_fc
fuse_conv_elementwise_add_
relu
fuse_conv_elementwise_add_
activation
fuse_conv_bn
fuse_elementwise_add_activation
CACHE INTERNAL
"fusers"
)
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
...
...
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_
relu
_fuser.cc
→
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_
activation
_fuser.cc
浏览文件 @
de2a72a4
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_
relu
_fuser.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_
activation
_fuser.h"
#include <memory>
#include <vector>
...
...
@@ -21,7 +21,7 @@ namespace lite {
namespace
mir
{
namespace
fusion
{
void
ConvElementwiseAdd
ReLU
Fuser
::
BuildPattern
()
{
void
ConvElementwiseAdd
Activation
Fuser
::
BuildPattern
()
{
// create input nodes.
auto
*
input
=
VarNode
(
"input"
)
->
assert_is_op_input
(
conv_type_
,
"Input"
)
->
AsInput
();
...
...
@@ -36,7 +36,8 @@ void ConvElementwiseAddReLUFuser::BuildPattern() {
auto
*
add
=
OpNode
(
"add"
,
"elementwise_add"
)
->
assert_is_op
(
"elementwise_add"
)
->
AsIntermediate
();
auto
*
relu
=
OpNode
(
"relu"
,
"relu"
)
->
assert_is_op
(
"relu"
)
->
AsIntermediate
();
auto
*
act
=
OpNode
(
"act"
,
act_type_
)
->
assert_is_op
(
act_type_
)
->
AsIntermediate
();
// create intermediate nodes
auto
*
conv2d_out
=
VarNode
(
"conv2d_out"
)
...
...
@@ -45,22 +46,23 @@ void ConvElementwiseAddReLUFuser::BuildPattern() {
->
AsIntermediate
();
auto
*
add_out
=
VarNode
(
"add_out"
)
->
assert_is_op_output
(
"elementwise_add"
,
"Out"
)
->
assert_is_op_input
(
"relu"
,
"X"
)
->
assert_is_op_input
(
act_type_
,
"X"
)
->
AsIntermediate
();
// create output node
auto
*
out
=
VarNode
(
"output"
)
->
assert_is_op_output
(
"relu"
,
"Out"
)
->
AsOutput
();
auto
*
out
=
VarNode
(
"output"
)
->
assert_is_op_output
(
act_type_
,
"Out"
)
->
AsOutput
();
// create topology.
std
::
vector
<
PMNode
*>
conv2d_inputs
{
filter
,
input
};
std
::
vector
<
PMNode
*>
add_inputs
{
conv2d_out
,
bias
};
conv2d_inputs
>>
*
conv2d
>>
*
conv2d_out
;
add_inputs
>>
*
add
>>
*
add_out
;
*
add_out
>>
*
relu
>>
*
out
;
*
add_out
>>
*
act
>>
*
out
;
}
void
ConvElementwiseAdd
ReLUFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
void
ConvElementwiseAdd
ActivationFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
conv_op
=
LiteOpRegistry
::
Global
().
Create
(
conv_type_
);
auto
conv_old
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op
;
...
...
@@ -76,7 +78,8 @@ void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
IR_NODE_LINK_TO
(
new_op_node
,
matched
.
at
(
"output"
));
}
cpp
::
OpDesc
ConvElementwiseAddReLUFuser
::
GenOpDesc
(
const
key2nodes_t
&
matched
)
{
cpp
::
OpDesc
ConvElementwiseAddActivationFuser
::
GenOpDesc
(
const
key2nodes_t
&
matched
)
{
auto
*
desc
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op_info
();
cpp
::
OpDesc
op_desc
;
...
...
@@ -98,6 +101,7 @@ cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
op_desc
.
SetAttr
(
"paddings"
,
desc
->
GetAttr
<
std
::
vector
<
int
>>
(
"paddings"
));
op_desc
.
SetAttr
(
"groups"
,
desc
->
GetAttr
<
int
>
(
"groups"
));
op_desc
.
SetAttr
(
"dilations"
,
desc
->
GetAttr
<
std
::
vector
<
int
>>
(
"dilations"
));
// TODO(sangoly): support other activation types
op_desc
.
SetAttr
(
"fuse_relu"
,
true
);
return
op_desc
;
}
...
...
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
class
ConvElementwiseAddActivationFuser
:
public
FuseBase
{
public:
explicit
ConvElementwiseAddActivationFuser
(
const
std
::
string
&
conv_type
,
const
std
::
string
&
act_type
)
{
CHECK
(
act_type
==
"relu"
)
<<
"Only relu activation be supported now"
;
conv_type_
=
conv_type
;
act_type_
=
act_type
;
}
void
BuildPattern
()
override
;
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
;
private:
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
override
;
std
::
string
conv_type_
;
std
::
string
act_type_
;
};
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.cc
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
#include <memory>
#include <vector>
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
void
ElementwiseAddActivationFuser
::
BuildPattern
()
{
// create input nodes.
auto
*
x
=
VarNode
(
"x"
)
->
assert_is_op_input
(
"elementwise_add"
,
"X"
)
->
AsInput
();
auto
*
y
=
VarNode
(
"y"
)
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
)
->
AsInput
();
// create op nodes
auto
*
add
=
OpNode
(
"add"
,
"elementwise_add"
)
->
assert_is_op
(
"elementwise_add"
)
->
AsIntermediate
();
auto
*
act
=
OpNode
(
"act"
,
act_type_
)
->
assert_is_op
(
act_type_
)
->
AsIntermediate
();
// create intermediate nodes
auto
*
add_out
=
VarNode
(
"add_out"
)
->
assert_is_op_output
(
"elementwise_add"
,
"Out"
)
->
assert_is_op_input
(
act_type_
,
"X"
)
->
AsIntermediate
();
// create output node
auto
*
out
=
VarNode
(
"output"
)
->
assert_is_op_output
(
act_type_
,
"Out"
)
->
AsOutput
();
// create topology.
std
::
vector
<
PMNode
*>
add_inputs
{
x
,
y
};
add_inputs
>>
*
add
>>
*
add_out
;
*
add_out
>>
*
act
>>
*
out
;
}
void
ElementwiseAddActivationFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
op
=
LiteOpRegistry
::
Global
().
Create
(
"fusion_elementwise_add_activation"
);
auto
old_op
=
matched
.
at
(
"add"
)
->
stmt
()
->
op
;
auto
*
scope
=
old_op
->
scope
();
auto
&
valid_places
=
old_op
->
valid_places
();
op
->
Attach
(
op_desc
,
scope
);
auto
*
new_op_node
=
graph
->
GraphCreateInstructNode
(
op
,
valid_places
);
IR_NODE_LINK_TO
(
matched
.
at
(
"x"
),
new_op_node
);
IR_NODE_LINK_TO
(
matched
.
at
(
"y"
),
new_op_node
);
IR_NODE_LINK_TO
(
new_op_node
,
matched
.
at
(
"output"
));
}
cpp
::
OpDesc
ElementwiseAddActivationFuser
::
GenOpDesc
(
const
key2nodes_t
&
matched
)
{
auto
*
desc
=
matched
.
at
(
"add"
)
->
stmt
()
->
op_info
();
cpp
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"fusion_elementwise_add_activation"
);
op_desc
.
SetInput
(
"X"
,
{
matched
.
at
(
"x"
)
->
arg
()
->
name
});
op_desc
.
SetInput
(
"Y"
,
{
matched
.
at
(
"y"
)
->
arg
()
->
name
});
op_desc
.
SetOutput
(
"Out"
,
{
matched
.
at
(
"output"
)
->
arg
()
->
name
});
op_desc
.
SetAttr
(
"axis"
,
desc
->
GetAttr
<
int
>
(
"axis"
));
op_desc
.
SetAttr
(
"act_type"
,
act_type_
);
return
op_desc
;
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/
conv_elementwise_add_relu
_fuser.h
→
paddle/fluid/lite/core/mir/fusion/
elementwise_add_activation
_fuser.h
浏览文件 @
de2a72a4
...
...
@@ -23,16 +23,16 @@ namespace lite {
namespace
mir
{
namespace
fusion
{
class
ConvElementwiseAddReLU
Fuser
:
public
FuseBase
{
class
ElementwiseAddActivation
Fuser
:
public
FuseBase
{
public:
explicit
ConvElementwiseAddReLUFuser
(
const
std
::
string
&
conv
_type
)
:
conv_type_
(
conv
_type
)
{}
explicit
ElementwiseAddActivationFuser
(
const
std
::
string
&
act
_type
)
:
act_type_
(
act
_type
)
{}
void
BuildPattern
()
override
;
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
;
private:
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
override
;
std
::
string
conv
_type_
;
std
::
string
act
_type_
;
};
}
// namespace fusion
...
...
paddle/fluid/lite/core/mir/passes.h
浏览文件 @
de2a72a4
...
...
@@ -34,4 +34,5 @@ USE_MIR_PASS(runtime_context_assign_pass);
USE_MIR_PASS
(
lite_conv_bn_fuse_pass
);
USE_MIR_PASS
(
graph_visualze
);
USE_MIR_PASS
(
lite_fc_fuse_pass
);
USE_MIR_PASS
(
lite_conv_elementwise_add_act_fuse_pass
);
USE_MIR_PASS
(
lite_conv_elementwise_add_activation_fuse_pass
);
USE_MIR_PASS
(
lite_elementwise_add_activation_fuse_pass
);
paddle/fluid/lite/core/optimizer.h
浏览文件 @
de2a72a4
...
...
@@ -49,8 +49,11 @@ class Optimizer {
if
(
passes
.
empty
())
{
RunPasses
(
std
::
vector
<
std
::
string
>
{{
"lite_conv_bn_fuse_pass"
,
//
"lite_conv_elementwise_add_act_fuse_pass"
,
//
"lite_conv_elementwise_add_act
ivation
_fuse_pass"
,
//
"lite_fc_fuse_pass"
,
//
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"lite_elementwise_add_activation_fuse_pass"
,
//
#endif
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"static_kernel_pick_pass"
,
//
"variable_place_inference_pass"
,
//
...
...
paddle/fluid/lite/operators/CMakeLists.txt
浏览文件 @
de2a72a4
...
...
@@ -14,6 +14,7 @@ cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
cc_library
(
io_copy_op_lite SRCS io_copy_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
activation_ops_lite SRCS activation_ops.cc DEPS
${
op_DEPS
}
)
cc_library
(
elementwise_ops_lite SRCS elementwise_ops.cc DEPS
${
op_DEPS
}
)
cc_library
(
fusion_elementwise_activation_ops_lite SRCS fusion_elementwise_activation_ops.cc DEPS elementwise_ops_lite
${
op_DEPS
}
)
cc_library
(
mean_op_lite SRCS mean_op.cc DEPS
${
op_DEPS
}
)
cc_library
(
fill_constant_op_lite SRCS fill_constant_op.cc DEPS
${
op_DEPS
}
)
#cc_library(sgd_op_lite SRCS sgd_op.cc DEPS ${op_DEPS})
...
...
@@ -36,6 +37,7 @@ set(ops_lite
fetch_op_lite
io_copy_op_lite
elementwise_ops_lite
fusion_elementwise_activation_ops_lite
mean_op_lite
fill_constant_op_lite
activation_ops_lite
...
...
@@ -56,3 +58,6 @@ lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite m
lite_cc_test
(
test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite
)
lite_cc_test
(
test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite
)
lite_cc_test
(
test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite
)
lite_cc_test
(
test_fusion_elementwise_activation_ops_lite
SRCS fusion_elementwise_activation_ops_test.cc
DEPS fusion_elementwise_activation_ops_lite memory_lite
)
paddle/fluid/lite/operators/elementwise_ops.cc
浏览文件 @
de2a72a4
...
...
@@ -12,31 +12,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/
core/op_lite
.h"
#include "paddle/fluid/lite/
operators/elementwise_ops
.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
class
ElementwiseOp
:
public
OpLite
{
public:
explicit
ElementwiseOp
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
{
bool
ElementwiseOp
::
CheckShape
()
const
{
CHECK_OR_FALSE
(
param_
.
X
);
CHECK_OR_FALSE
(
param_
.
Y
);
CHECK_OR_FALSE
(
param_
.
Out
);
return
true
;
}
}
bool
InferShape
()
const
override
{
bool
ElementwiseOp
::
InferShape
()
const
{
CHECK_OR_FALSE
(
param_
.
X
->
dims
().
size
()
>=
param_
.
Y
->
dims
().
size
());
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
return
true
;
}
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
ElementwiseOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
auto
X_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
Y_name
=
opdesc
.
Input
(
"Y"
).
front
();
auto
Out_name
=
opdesc
.
Output
(
"Out"
).
front
();
...
...
@@ -46,36 +42,25 @@ class ElementwiseOp : public OpLite {
param_
.
Out
=
GetMutableVar
<
lite
::
Tensor
>
(
scope
,
Out_name
);
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"elementwise_op"
;
}
private:
mutable
operators
::
ElementwiseParam
param_
;
};
}
#ifdef LITE_WITH_X86
class
ElementwiseGradExplicitOp
:
public
OpLite
{
public:
explicit
ElementwiseGradExplicitOp
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
{
bool
ElementwiseGradExplicitOp
::
CheckShape
()
const
{
CHECK_OR_FALSE
(
param_
.
Y
);
CHECK_OR_FALSE
(
param_
.
X_grad
);
CHECK_OR_FALSE
(
param_
.
Y_grad
);
CHECK_OR_FALSE
(
param_
.
Out_grad
);
return
true
;
}
}
bool
InferShape
()
const
override
{
bool
ElementwiseGradExplicitOp
::
InferShape
()
const
{
param_
.
X_grad
->
Resize
(
param_
.
Out_grad
->
dims
());
param_
.
Y_grad
->
Resize
(
param_
.
Y
->
dims
());
return
true
;
}
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
ElementwiseGradExplicitOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
CHECK_EQ
(
opdesc
.
InputArgumentNames
().
size
(),
1UL
);
auto
Out_name
=
opdesc
.
Input
(
framework
::
GradVarName
(
"Out"
)).
front
();
auto
X_name
=
opdesc
.
Output
(
framework
::
GradVarName
(
"X"
)).
front
();
...
...
@@ -87,17 +72,7 @@ class ElementwiseGradExplicitOp : public OpLite {
param_
.
axis
=
opdesc
.
GetAttr
<
int
>
(
"axis"
);
return
true
;
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"elementwise_grad_explicit_op"
;
}
private:
mutable
operators
::
ElementwiseGradParam
param_
;
};
}
#endif
}
// namespace operators
...
...
paddle/fluid/lite/operators/elementwise_ops.h
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
class
ElementwiseOp
:
public
OpLite
{
public:
explicit
ElementwiseOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"elementwise_op"
;
}
private:
mutable
operators
::
ElementwiseParam
param_
;
};
#ifdef LITE_WITH_X86
class
ElementwiseGradExplicitOp
:
public
OpLite
{
public:
explicit
ElementwiseGradExplicitOp
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"elementwise_grad_explicit_op"
;
}
private:
mutable
operators
::
ElementwiseGradParam
param_
;
};
#endif
}
// namespace operators
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/operators/fusion_elementwise_activation_ops.cc
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h"
#include <string>
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
// Binds the fused op's parameters: first the plain elementwise fields
// (X, Y, Out, axis) via the base class, then the activation attribute.
// Returns false if the base binding fails.
bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc &opdesc,
                                               lite::Scope *scope) {
  // Propagate a base-class binding failure instead of silently ignoring it.
  if (!ElementwiseOp::AttachImpl(opdesc, scope)) return false;
  // NOTE(review): ElementwiseOp::AttachImpl fills the *base* class's private
  // param_, while act_type below goes into this class's shadowing param_ —
  // verify that kernels actually receive the fused parameter with both the
  // tensors and act_type set.
  param_.act_type = opdesc.GetAttr<std::string>("act_type");
  // TODO(sangoly): support more activation types.
  CHECK(param_.act_type == "relu") << "Only relu activation is supported now";
  return true;
}
#ifdef LITE_WITH_X86
// Binds the fused gradient op's parameters: gradient tensors via the base
// class, then the activation attribute. Returns false if the base binding
// fails.
bool FusionElementwiseActivationGradExplicitOp::AttachImpl(
    const cpp::OpDesc &opdesc, lite::Scope *scope) {
  // Propagate a base-class binding failure instead of silently ignoring it.
  if (!ElementwiseGradExplicitOp::AttachImpl(opdesc, scope)) return false;
  // NOTE(review): the base AttachImpl fills the base class's private param_,
  // while act_type goes into this class's shadowing param_ — verify kernels
  // receive a fully-populated fused parameter.
  param_.act_type = opdesc.GetAttr<std::string>("act_type");
  // TODO(sangoly): support more activation types.
  CHECK(param_.act_type == "relu") << "Only relu activation is supported now";
  return true;
}
#endif
}
// namespace operators
}
// namespace lite
}
// namespace paddle
// Register the fused elementwise + activation operators. The "add" and "sub"
// forward variants share the same FusionElementwiseActivationOp class; the
// arithmetic is dispatched by the op-type string. Each registration is
// independent, so their order is irrelevant.
REGISTER_LITE_OP(fusion_elementwise_sub_activation,
                 paddle::lite::operators::FusionElementwiseActivationOp);
REGISTER_LITE_OP(fusion_elementwise_add_activation,
                 paddle::lite::operators::FusionElementwiseActivationOp);

#ifdef LITE_WITH_X86
// Gradient op only exists on the x86 (training-capable) build.
REGISTER_LITE_OP(
    fusion_elementwise_sub_activation_grad,
    paddle::lite::operators::FusionElementwiseActivationGradExplicitOp);
#endif
paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/operators/elementwise_ops.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
// Fused elementwise + activation forward operator. Reuses ElementwiseOp for
// the arithmetic binding and adds the activation-type attribute on top.
class FusionElementwiseActivationOp : public ElementwiseOp {
 public:
  explicit FusionElementwiseActivationOp(const std::string &type)
      : ElementwiseOp(type) {}

  // Binds elementwise fields via the base, then reads "act_type".
  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;

  std::string DebugString() const override {
    return "fusion_elementwise_activation_op";
  }

 private:
  // NOTE(review): this member shadows ElementwiseOp::param_, and the
  // inherited AttachKernel passes the *base* member to the kernel — confirm
  // kernels actually see act_type.
  mutable operators::FusionElementwiseActivationParam param_;
};
#ifdef LITE_WITH_X86
// Fused elementwise + activation gradient operator (x86-only build). Reuses
// ElementwiseGradExplicitOp for the gradient binding and adds the activation
// type attribute.
class FusionElementwiseActivationGradExplicitOp
    : public ElementwiseGradExplicitOp {
 public:
  explicit FusionElementwiseActivationGradExplicitOp(const std::string &type)
      : ElementwiseGradExplicitOp(type) {}

  // Binds gradient fields via the base, then reads "act_type".
  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;

  std::string DebugString() const override {
    return "fusion_elementwise_activation_grad_explicit_op";
  }

 private:
  // NOTE(review): this member shadows ElementwiseGradExplicitOp::param_, and
  // the inherited AttachKernel passes the *base* member to the kernel —
  // confirm kernels actually see act_type.
  mutable operators::FusionElementwiseActivationGradParam param_;
};
#endif
}
// namespace operators
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/operators/fusion_elementwise_activation_ops_test.cc
0 → 100644
浏览文件 @
de2a72a4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/fusion_elementwise_activation_ops.h"
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
// Smoke test: builds a 10x20 fusion_elementwise_add_activation op desc and
// checks that Attach() accepts it (shape checks + attribute binding).
TEST(fusion_elementwise_activation_op_lite, test) {
  // Prepare variables in a scratch scope.
  lite::Scope scope;
  auto *x = scope.Var("x")->GetMutable<lite::Tensor>();
  auto *y = scope.Var("y")->GetMutable<lite::Tensor>();
  auto *out = scope.Var("out")->GetMutable<lite::Tensor>();
  const std::vector<int64_t> shape({10, 20});
  x->Resize(lite::DDim(shape));
  y->Resize(lite::DDim(shape));
  out->Resize(lite::DDim(shape));

  // Fill inputs with ramp data and zero the output. The three buffers are
  // independent, so one fused loop is equivalent to three separate ones.
  auto *x_data = x->mutable_data<float>();
  auto *y_data = y->mutable_data<float>();
  auto *out_data = out->mutable_data<float>();
  for (int i = 0; i < 10 * 20; i++) {
    x_data[i] = i;
    y_data[i] = i;
    out_data[i] = 0.;
  }

  // Prepare the op desc for the fused add + relu op.
  cpp::OpDesc desc;
  desc.SetType("fusion_elementwise_add_activation");
  desc.SetInput("X", {"x"});
  desc.SetInput("Y", {"y"});
  desc.SetOutput("Out", {"out"});
  desc.SetAttr("axis", static_cast<int>(1));
  desc.SetAttr("act_type", std::string("relu"));

  FusionElementwiseActivationOp fuse_op("fusion_elementwise_add_activation");
  fuse_op.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)}});
  fuse_op.Attach(desc, &scope);
}
}
// namespace operators
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/operators/op_params.h
浏览文件 @
de2a72a4
...
...
@@ -219,6 +219,14 @@ struct ElementwiseGradParam {
int
axis
{
-
1
};
// for broadcasting.
};
// Parameters for the fused elementwise + activation forward op: the plain
// elementwise fields plus the activation kind (the fused op currently
// accepts only "relu").
struct FusionElementwiseActivationParam : public ElementwiseParam {
  std::string act_type;
};
// Parameters for the fused elementwise + activation gradient op: the plain
// elementwise-grad fields plus the activation kind (the fused op currently
// accepts only "relu").
struct FusionElementwiseActivationGradParam : public ElementwiseGradParam {
  std::string act_type;
};
/// ----------------------- activation operators ----------------------
struct
ActivationParam
{
const
lite
::
Tensor
*
X
{};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录