Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
41602396
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
41602396
编写于
6月 13, 2019
作者:
S
sangoly
提交者:
GitHub
6月 13, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add conv_elementwise_add_relu_fuse_pass test=develop (#18079)
上级
234fb8f4
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
399 addition
and
8 deletion
+399
-8
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+7
-1
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc
...luid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc
+36
-0
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h
...fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h
+32
-0
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
...lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
+153
-0
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
+9
-1
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
...d/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
+104
-0
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h
...id/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h
+38
-0
paddle/fluid/lite/core/mir/passes.h
paddle/fluid/lite/core/mir/passes.h
+1
-0
paddle/fluid/lite/operators/conv_op.h
paddle/fluid/lite/operators/conv_op.h
+19
-5
paddle/fluid/lite/operators/relu_op.cc
paddle/fluid/lite/operators/relu_op.cc
+0
-1
未找到文件。
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
41602396
...
...
@@ -6,6 +6,7 @@ cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
add_subdirectory
(
fusion
)
cc_library
(
mir_passes
SRCS fc_fuse_pass.cc
conv_elementwise_add_relu_fuse_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
type_target_transform_pass.cc
...
...
@@ -15,7 +16,7 @@ cc_library(mir_passes
argument_type_display_pass.cc
demo_pass.cc
runtime_context_assign_pass.cc
DEPS mir_pass types_lite context_lite
mir_fusers
)
DEPS mir_pass types_lite context_lite
${
mir_fusers
}
)
# for mobile, unnecessary to compile the following testings.
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
...
...
@@ -74,3 +75,8 @@ lite_cc_test(test_lite_fc_fuse SRCS fc_fuse_pass_test.cc
lite_download_and_uncompress
(
${
LITE_MODEL_DIR
}
${
LITE_URL
}
"lite_fc_model.tar.gz"
)
add_dependencies
(
test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz
)
lite_cc_test
(
test_lite_conv_elementwise_add_relu_fuse
SRCS conv_elementwise_add_relu_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
)
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc
0 → 100644
浏览文件 @
41602396
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
ConvElementwiseAddReLUFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
ConvElementwiseAddReLUFuser
fuser
;
fuser
(
graph
.
get
());
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_conv_elementwise_add_act_fuse_pass
,
paddle
::
lite
::
mir
::
ConvElementwiseAddReLUFusePass
);
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h
0 → 100644
浏览文件 @
41602396
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
ConvElementwiseAddReLUFusePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
0 → 100644
浏览文件 @
41602396
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
DEFINE_string
(
model_dir
,
""
,
""
);
DEFINE_string
(
optimized_model
,
""
,
""
);
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
std
::
unique_ptr
<
SSAGraph
>
BuildGraph
(
framework
::
ProgramDesc
*
program_desc
,
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
std
::
vector
<
Place
>&
valid_places
)
{
auto
*
main_block
=
program_desc
->
MutableBlock
(
0
);
auto
*
conv2d_1
=
main_block
->
AppendOp
();
auto
*
conv2d_2
=
main_block
->
AppendOp
();
auto
*
add_1
=
main_block
->
AppendOp
();
auto
*
add_2
=
main_block
->
AppendOp
();
auto
*
relu_1
=
main_block
->
AppendOp
();
auto
*
relu_2
=
main_block
->
AppendOp
();
main_block
->
Var
(
"input_1"
);
main_block
->
Var
(
"input_2"
);
main_block
->
Var
(
"filter_1"
);
main_block
->
Var
(
"filter_2"
);
main_block
->
Var
(
"conv2d_1_out"
);
main_block
->
Var
(
"conv2d_2_out"
);
main_block
->
Var
(
"bias_1"
);
main_block
->
Var
(
"add_1_out"
);
main_block
->
Var
(
"add_2_out"
);
main_block
->
Var
(
"relu_1_out"
);
main_block
->
Var
(
"out"
);
scope
->
Var
(
"input_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"input_2"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"filter_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"filter_2"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"conv2d_1_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"conv2d_2_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"bias_1"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"add_1_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"add_2_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"relu_1_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"out"
)
->
GetMutable
<
lite
::
Tensor
>
();
conv2d_1
->
SetType
(
"conv2d"
);
conv2d_1
->
SetInput
(
"Input"
,
{
"input_1"
});
conv2d_1
->
SetInput
(
"Filter"
,
{
"filter_1"
});
conv2d_1
->
SetOutput
(
"Output"
,
{
"conv2d_1_out"
});
conv2d_1
->
SetAttr
(
"strides"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_1
->
SetAttr
(
"paddings"
,
std
::
vector
<
int
>
({
0
,
0
}));
conv2d_1
->
SetAttr
(
"groups"
,
1
);
conv2d_1
->
SetAttr
(
"dilations"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_1
->
SetAttr
(
"fuse_relu"
,
false
);
add_1
->
SetType
(
"elementwise_add"
);
add_1
->
SetInput
(
"X"
,
{
"conv2d_1_out"
});
add_1
->
SetInput
(
"Y"
,
{
"bias_1"
});
add_1
->
SetOutput
(
"Out"
,
{
"add_1_out"
});
add_1
->
SetAttr
(
"axis"
,
1
);
relu_1
->
SetType
(
"relu"
);
relu_1
->
SetInput
(
"Input"
,
{
"add_1_out"
});
relu_1
->
SetOutput
(
"Out"
,
{
"relu_1_out"
});
conv2d_2
->
SetType
(
"conv2d"
);
conv2d_2
->
SetInput
(
"Input"
,
{
"input_2"
});
conv2d_2
->
SetInput
(
"Filter"
,
{
"filter_2"
});
conv2d_2
->
SetOutput
(
"Output"
,
{
"conv2d_2_out"
});
conv2d_2
->
SetAttr
(
"strides"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_2
->
SetAttr
(
"paddings"
,
std
::
vector
<
int
>
({
0
,
0
}));
conv2d_2
->
SetAttr
(
"groups"
,
1
);
conv2d_2
->
SetAttr
(
"dilations"
,
std
::
vector
<
int
>
({
1
,
1
}));
conv2d_2
->
SetAttr
(
"fuse_relu"
,
false
);
add_2
->
SetType
(
"elementwise_add"
);
add_2
->
SetInput
(
"X"
,
{
"conv2d_2_out"
});
add_2
->
SetInput
(
"Y"
,
{
"relu_1_out"
});
add_2
->
SetOutput
(
"Out"
,
{
"add_2_out"
});
add_2
->
SetAttr
(
"axis"
,
1
);
relu_2
->
SetType
(
"relu"
);
relu_2
->
SetInput
(
"Input"
,
{
"add_2_out"
});
relu_2
->
SetOutput
(
"Out"
,
{
"out"
});
program_desc
->
Flush
();
lite
::
Program
program
(
*
program_desc
->
Proto
(),
scope
,
valid_places
);
auto
graph
=
std
::
unique_ptr
<
SSAGraph
>
(
new
SSAGraph
());
graph
->
Build
(
program
,
valid_places
);
return
graph
;
}
TEST
(
conv_elementwise_add_relu_fuse_pass
,
graph_test
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
ASSERT_EQ
(
graph
->
nodes
().
size
(),
11UL
/*vars*/
+
6UL
/*ops*/
+
2UL
/*feed op + fetch op*/
);
Visualize
(
graph
.
get
());
}
TEST
(
conv_elementwise_add_relu_fuse_pass
,
fuse_test_op
)
{
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
Visualize
(
graph
.
get
());
const
int
num_nodes
=
graph
->
nodes
().
size
();
auto
*
fuser
=
new
ConvElementwiseAddReLUFusePass
;
fuser
->
Apply
(
graph
);
Visualize
(
graph
.
get
());
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
5UL
*
2
/*nodes removed */
+
1UL
*
2
/* fused fc node*/
);
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
USE_LITE_OP
(
elementwise_add
);
USE_LITE_OP
(
conv2d
);
USE_LITE_OP
(
depthwise_conv2d
);
USE_LITE_OP
(
relu
);
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
浏览文件 @
41602396
cc_library
(
mir_fusers
cc_library
(
fuse_fc
SRCS fc_fuser.cc
DEPS pattern_matcher_high_api
)
cc_library
(
conv_elementwise_add_relu
SRCS conv_elementwise_add_relu_fuser.cc
DEPS pattern_matcher_high_api
)
set
(
mir_fusers
fuse_fc
conv_elementwise_add_relu
CACHE INTERNAL
"fusers"
)
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
0 → 100644
浏览文件 @
41602396
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include <memory>
#include <vector>
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
void
ConvElementwiseAddReLUFuser
::
BuildPattern
()
{
// create input nodes.
auto
*
input
=
VarNode
(
"input"
);
auto
*
filter
=
VarNode
(
"filter"
);
auto
*
bias
=
VarNode
(
"bias"
);
// create op nodes
auto
*
conv2d
=
OpNode
(
"conv2d"
,
"conv2d"
);
auto
*
add
=
OpNode
(
"add"
,
"elementwise_add"
);
auto
*
relu
=
OpNode
(
"relu"
,
"relu"
);
// create intermediate nodes
auto
*
conv2d_out
=
VarNode
(
"conv2d_out"
);
auto
*
add_out
=
VarNode
(
"add_out"
);
// create output node
auto
*
out
=
VarNode
(
"output"
);
// create topology.
std
::
vector
<
PMNode
*>
conv2d_inputs
{
filter
,
input
};
std
::
vector
<
PMNode
*>
add_inputs
{
conv2d_out
,
bias
};
conv2d_inputs
>>
*
conv2d
>>
*
conv2d_out
;
add_inputs
>>
*
add
>>
*
add_out
;
*
add_out
>>
*
relu
>>
*
out
;
// Some op specialities.
conv2d_out
->
AsIntermediate
();
add_out
->
AsIntermediate
();
conv2d
->
AsIntermediate
();
add
->
AsIntermediate
();
relu
->
AsIntermediate
();
}
void
ConvElementwiseAddReLUFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
conv_op
=
LiteOpRegistry
::
Global
().
Create
(
"conv2d"
);
auto
conv_old
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op
;
auto
*
scope
=
conv_old
->
scope
();
auto
&
valid_places
=
conv_old
->
valid_places
();
conv_op
->
Attach
(
op_desc
,
scope
);
auto
*
new_op_node
=
graph
->
GraphCreateInstructNode
(
conv_op
,
valid_places
);
IR_NODE_LINK_TO
(
matched
.
at
(
"input"
),
new_op_node
);
IR_NODE_LINK_TO
(
matched
.
at
(
"filter"
),
new_op_node
);
IR_NODE_LINK_TO
(
matched
.
at
(
"bias"
),
new_op_node
);
IR_NODE_LINK_TO
(
new_op_node
,
matched
.
at
(
"output"
));
}
cpp
::
OpDesc
ConvElementwiseAddReLUFuser
::
GenOpDesc
(
const
key2nodes_t
&
matched
)
{
auto
*
desc
=
matched
.
at
(
"conv2d"
)
->
stmt
()
->
op_info
();
cpp
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"conv2d"
);
op_desc
.
SetInput
(
"Input"
,
{
matched
.
at
(
"input"
)
->
arg
()
->
name
});
op_desc
.
SetInput
(
"Filter"
,
{
matched
.
at
(
"filter"
)
->
arg
()
->
name
});
op_desc
.
SetInput
(
"Bias"
,
{
matched
.
at
(
"bias"
)
->
arg
()
->
name
});
op_desc
.
SetOutput
(
"Output"
,
{
matched
.
at
(
"output"
)
->
arg
()
->
name
});
// Other inputs. See operators/conv_op.h
std
::
vector
<
std
::
string
>
input_arg_names
=
desc
->
InputArgumentNames
();
for
(
auto
name
:
input_arg_names
)
LOG
(
INFO
)
<<
name
;
if
(
std
::
find
(
input_arg_names
.
begin
(),
input_arg_names
.
end
(),
"ResidualData"
)
!=
input_arg_names
.
end
())
{
op_desc
.
SetInput
(
"ResidualData"
,
desc
->
Input
(
"ResidualData"
));
}
// Only consider strides, padding, groups, dilations, fuse_relu for now
op_desc
.
SetAttr
(
"strides"
,
desc
->
GetAttr
<
std
::
vector
<
int
>>
(
"strides"
));
op_desc
.
SetAttr
(
"paddings"
,
desc
->
GetAttr
<
std
::
vector
<
int
>>
(
"paddings"
));
op_desc
.
SetAttr
(
"groups"
,
desc
->
GetAttr
<
int
>
(
"groups"
));
op_desc
.
SetAttr
(
"dilations"
,
desc
->
GetAttr
<
std
::
vector
<
int
>>
(
"dilations"
));
op_desc
.
SetAttr
(
"fuse_relu"
,
true
);
return
op_desc
;
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h
0 → 100644
浏览文件 @
41602396
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
class
ConvElementwiseAddReLUFuser
:
public
FuseBase
{
public:
void
BuildPattern
()
override
;
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
;
private:
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
override
;
};
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/passes.h
浏览文件 @
41602396
...
...
@@ -23,6 +23,7 @@ namespace mir {} // namespace mir
USE_MIR_PASS
(
demo
);
USE_MIR_PASS
(
lite_fc_fuse_pass
);
USE_MIR_PASS
(
lite_conv_elementwise_add_act_fuse_pass
);
USE_MIR_PASS
(
static_kernel_pick_pass
);
USE_MIR_PASS
(
variable_place_inference_pass
);
USE_MIR_PASS
(
type_target_transform_pass
);
...
...
paddle/fluid/lite/operators/conv_op.h
浏览文件 @
41602396
...
...
@@ -64,17 +64,31 @@ class ConvOpLite : public OpLite {
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
X
=
op_desc
.
Input
(
"Input"
).
front
();
auto
Filter
=
op_desc
.
Input
(
"Filter"
).
front
();
auto
Bias
=
op_desc
.
Input
(
"Bias"
).
front
();
// auto ResidualData = op_desc.Input("ResidualData");
auto
Out
=
op_desc
.
Output
(
"Output"
).
front
();
param_
.
x
=
scope
->
FindVar
(
X
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
filter
=
scope
->
FindVar
(
Filter
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
bias
=
scope
->
FindVar
(
Bias
)
->
GetMutable
<
lite
::
Tensor
>
();
// param_.residualData =
// scope->FindVar(ResidualData)->GetMutable<lite::Tensor>();
param_
.
output
=
scope
->
FindVar
(
Out
)
->
GetMutable
<
lite
::
Tensor
>
();
std
::
vector
<
std
::
string
>
input_arg_names
=
op_desc
.
InputArgumentNames
();
if
(
std
::
find
(
input_arg_names
.
begin
(),
input_arg_names
.
end
(),
"Bias"
)
!=
input_arg_names
.
end
())
{
auto
bias_var
=
scope
->
FindVar
(
op_desc
.
Input
(
"Bias"
).
front
());
if
(
bias_var
!=
nullptr
)
{
param_
.
bias
=
const_cast
<
lite
::
Tensor
*>
(
&
(
bias_var
->
Get
<
lite
::
Tensor
>
()));
}
}
if
(
std
::
find
(
input_arg_names
.
begin
(),
input_arg_names
.
end
(),
"ResidualData"
)
!=
input_arg_names
.
end
())
{
auto
residual_data_var
=
scope
->
FindVar
(
op_desc
.
Input
(
"ResidualData"
).
front
());
if
(
residual_data_var
!=
nullptr
)
{
param_
.
residualData
=
const_cast
<
lite
::
Tensor
*>
(
&
(
residual_data_var
->
Get
<
lite
::
Tensor
>
()));
}
}
param_
.
strides
=
op_desc
.
GetAttr
<
std
::
vector
<
int
>>
(
"strides"
);
param_
.
paddings
=
op_desc
.
GetAttr
<
std
::
vector
<
int
>>
(
"paddings"
);
param_
.
groups
=
op_desc
.
GetAttr
<
int
>
(
"groups"
);
...
...
paddle/fluid/lite/operators/relu_op.cc
浏览文件 @
41602396
...
...
@@ -37,7 +37,6 @@ bool ReluOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
scope
->
FindVar
(
opdesc
.
Output
(
"Out"
).
front
())
->
GetMutable
<
lite
::
Tensor
>
();
CHECK
(
param_
.
input
);
CHECK
(
param_
.
output
);
kernel_
->
SetParam
(
param_
);
return
true
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录