Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
234fb8f4
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
234fb8f4
编写于
6月 13, 2019
作者:
Z
Zhen Wang
提交者:
Yan Chunwei
6月 13, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add Fc fuse pass (#17994)
上级
016445c4
变更
21
显示空白变更内容
内联
并排
Showing
21 changed file
with
430 addition
and
77 deletion
+430
-77
paddle/fluid/lite/CMakeLists.txt
paddle/fluid/lite/CMakeLists.txt
+3
-2
paddle/fluid/lite/api/CMakeLists.txt
paddle/fluid/lite/api/CMakeLists.txt
+2
-4
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+4
-1
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+18
-3
paddle/fluid/lite/core/mir/fc_fuse_pass.cc
paddle/fluid/lite/core/mir/fc_fuse_pass.cc
+34
-0
paddle/fluid/lite/core/mir/fc_fuse_pass.h
paddle/fluid/lite/core/mir/fc_fuse_pass.h
+32
-0
paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc
paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc
+112
-0
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
+3
-0
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
+78
-0
paddle/fluid/lite/core/mir/fusion/fc_fuser.h
paddle/fluid/lite/core/mir/fusion/fc_fuser.h
+38
-0
paddle/fluid/lite/core/mir/passes.h
paddle/fluid/lite/core/mir/passes.h
+1
-0
paddle/fluid/lite/core/mir/pattern_matcher.cc
paddle/fluid/lite/core/mir/pattern_matcher.cc
+42
-1
paddle/fluid/lite/core/mir/pattern_matcher.h
paddle/fluid/lite/core/mir/pattern_matcher.h
+12
-1
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
+0
-1
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
+10
-13
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+1
-0
paddle/fluid/lite/core/profile/basic_profiler.h
paddle/fluid/lite/core/profile/basic_profiler.h
+1
-1
paddle/fluid/lite/kernels/host/CMakeLists.txt
paddle/fluid/lite/kernels/host/CMakeLists.txt
+1
-3
paddle/fluid/lite/kernels/x86/CMakeLists.txt
paddle/fluid/lite/kernels/x86/CMakeLists.txt
+3
-5
paddle/fluid/lite/kernels/x86/fc_compute.cc
paddle/fluid/lite/kernels/x86/fc_compute.cc
+17
-24
paddle/fluid/lite/operators/CMakeLists.txt
paddle/fluid/lite/operators/CMakeLists.txt
+18
-18
未找到文件。
paddle/fluid/lite/CMakeLists.txt
浏览文件 @
234fb8f4
...
@@ -10,6 +10,7 @@ message(STATUS "LITE_WITH_ARM:\t${LITE_WITH_ARM}")
...
@@ -10,6 +10,7 @@ message(STATUS "LITE_WITH_ARM:\t${LITE_WITH_ARM}")
message
(
STATUS
"LITE_WITH_PROFILE:
\t
${
LITE_WITH_PROFILE
}
"
)
message
(
STATUS
"LITE_WITH_PROFILE:
\t
${
LITE_WITH_PROFILE
}
"
)
set
(
LITE_MODEL_DIR
"
${
THIRD_PARTY_PATH
}
/install"
)
set
(
LITE_MODEL_DIR
"
${
THIRD_PARTY_PATH
}
/install"
)
set
(
LITE_URL
"http://paddle-inference-dist.bj.bcebos.com"
CACHE STRING
"inference download url"
)
function
(
lite_download_and_uncompress INSTALL_DIR URL FILENAME
)
function
(
lite_download_and_uncompress INSTALL_DIR URL FILENAME
)
message
(
STATUS
"Download inference test stuff from
${
URL
}
/
${
FILENAME
}
"
)
message
(
STATUS
"Download inference test stuff from
${
URL
}
/
${
FILENAME
}
"
)
...
@@ -161,13 +162,13 @@ function(lite_cc_test TARGET)
...
@@ -161,13 +162,13 @@ function(lite_cc_test TARGET)
file
(
APPEND
${
offline_test_registry_file
}
"
${
TARGET
}
\n
"
)
file
(
APPEND
${
offline_test_registry_file
}
"
${
TARGET
}
\n
"
)
endfunction
()
endfunction
()
add_subdirectory
(
operators
)
add_subdirectory
(
kernels
)
add_subdirectory
(
core
)
add_subdirectory
(
core
)
add_subdirectory
(
x86
)
add_subdirectory
(
x86
)
add_subdirectory
(
arm
)
add_subdirectory
(
arm
)
add_subdirectory
(
host
)
add_subdirectory
(
host
)
add_subdirectory
(
cuda
)
add_subdirectory
(
cuda
)
add_subdirectory
(
operators
)
add_subdirectory
(
kernels
)
add_subdirectory
(
model_parser
)
add_subdirectory
(
model_parser
)
add_subdirectory
(
utils
)
add_subdirectory
(
utils
)
add_subdirectory
(
api
)
add_subdirectory
(
api
)
...
...
paddle/fluid/lite/api/CMakeLists.txt
浏览文件 @
234fb8f4
...
@@ -5,7 +5,7 @@ if(LITE_WITH_CUDA)
...
@@ -5,7 +5,7 @@ if(LITE_WITH_CUDA)
nv_test
(
test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda
)
nv_test
(
test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda
)
endif
()
endif
()
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS
${
cxx_api_lite_deps
}
${
ops_lite
}
)
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS
${
cxx_api_lite_deps
}
${
ops_lite
}
program_lite
)
set
(
light_api_deps
set
(
light_api_deps
scope_lite target_wrapper_host model_parser_lite
)
scope_lite target_wrapper_host model_parser_lite
)
...
@@ -21,15 +21,13 @@ message(STATUS "get Host kernels ${host_kernels}")
...
@@ -21,15 +21,13 @@ message(STATUS "get Host kernels ${host_kernels}")
message
(
STATUS
"get ARM kernels
${
arm_kernels
}
"
)
message
(
STATUS
"get ARM kernels
${
arm_kernels
}
"
)
include
(
ExternalProject
)
include
(
ExternalProject
)
set
(
LITE_URL
"http://paddle-inference-dist.bj.bcebos.com"
CACHE STRING
"inference download url"
)
set
(
LITE_DEMO_INSTALL_DIR
"
${
THIRD_PARTY_PATH
}
/inference_demo"
CACHE STRING
set
(
LITE_DEMO_INSTALL_DIR
"
${
THIRD_PARTY_PATH
}
/inference_demo"
CACHE STRING
"A path setting inference demo download directories."
)
"A path setting inference demo download directories."
)
if
((
NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
AND WITH_TESTING
)
if
((
NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
AND WITH_TESTING
)
lite_cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc
lite_cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc
DEPS cxx_api_lite m
odel_parser_lite target_wrapper_host
DEPS cxx_api_lite m
ir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
PROFILE_DEPS basic_profiler_lite
ARGS --model_dir=
${
LITE_MODEL_DIR
}
/lite_naive_model
ARGS --model_dir=
${
LITE_MODEL_DIR
}
/lite_naive_model
--optimized_model=
${
LITE_MODEL_DIR
}
/lite_naive_model_opt SERIAL
)
--optimized_model=
${
LITE_MODEL_DIR
}
/lite_naive_model_opt SERIAL
)
...
...
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
234fb8f4
...
@@ -30,7 +30,10 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapp
...
@@ -30,7 +30,10 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapp
cc_library
(
types_lite SRCS types.cc
)
cc_library
(
types_lite SRCS types.cc
)
cc_library
(
type_system SRCS type_system.cc DEPS
${
tensor_lite
}
target_wrapper_lite
)
cc_library
(
type_system SRCS type_system.cc DEPS
${
tensor_lite
}
target_wrapper_lite
)
lite_cc_library
(
program_lite SRCS program.cc DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite HVY_DEPS framework_proto
)
lite_cc_library
(
program_lite SRCS program.cc
DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite
HVY_DEPS framework_proto
PROFILE_DEPS basic_profiler_lite
)
cc_library
(
optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite
)
cc_library
(
optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite
)
add_subdirectory
(
mir
)
add_subdirectory
(
mir
)
...
...
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
234fb8f4
...
@@ -3,8 +3,10 @@ cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node)
...
@@ -3,8 +3,10 @@ cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node)
cc_library
(
mir_pass SRCS pass.cc DEPS mir_ssa_graph
)
cc_library
(
mir_pass SRCS pass.cc DEPS mir_ssa_graph
)
cc_library
(
mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes
)
cc_library
(
mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes
)
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
add_subdirectory
(
fusion
)
cc_library
(
mir_passes
cc_library
(
mir_passes
SRCS static_kernel_pick_pass.cc
SRCS fc_fuse_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
variable_place_inference_pass.cc
type_target_transform_pass.cc
type_target_transform_pass.cc
io_copy_kernel_pick_pass.cc
io_copy_kernel_pick_pass.cc
...
@@ -13,7 +15,7 @@ cc_library(mir_passes
...
@@ -13,7 +15,7 @@ cc_library(mir_passes
argument_type_display_pass.cc
argument_type_display_pass.cc
demo_pass.cc
demo_pass.cc
runtime_context_assign_pass.cc
runtime_context_assign_pass.cc
DEPS mir_pass types_lite context_lite
)
DEPS mir_pass types_lite context_lite
mir_fusers
)
# for mobile, unnecessary to compile the following testings.
# for mobile, unnecessary to compile the following testings.
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
...
@@ -53,9 +55,22 @@ lite_cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern
...
@@ -53,9 +55,22 @@ lite_cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern
lite_cc_library
(
pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite
)
lite_cc_library
(
pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite
)
# TODO(wz) replace framework/proto to lite proto.
# TODO(wz) replace framework/proto to lite proto.
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
if
(
NOT
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
# it depends on the fluid/framework/proto, that is too heavy for mobile execution.
# it depends on the fluid/framework/proto, that is too heavy for mobile execution.
lite_cc_test
(
test_pattern_matcher_high_api SRCS pattern_matcher_high_api_test.cc DEPS
lite_cc_test
(
test_pattern_matcher_high_api SRCS pattern_matcher_high_api_test.cc DEPS
pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite elementwise_ops_lite
pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite elementwise_ops_lite
mir_passes compatible_pb_lite program_lite
${
ops_lite
}
)
mir_passes compatible_pb_lite program_lite
${
ops_lite
}
)
endif
()
endif
()
message
(
STATUS
"----> Ops lite:
${
ops_lite
}
"
)
message
(
STATUS
"----> Host kernels:
${
host_kernels
}
"
)
message
(
STATUS
"----> X86 kernels:
${
x86_kernels
}
"
)
lite_cc_test
(
test_lite_fc_fuse SRCS fc_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${
ops_lite
}
${
host_kernels
}
${
x86_kernels
}
ARGS --model_dir=
${
LITE_MODEL_DIR
}
/lite_fc_model
--optimized_model=
${
LITE_MODEL_DIR
}
/lite_fc_model_opt SERIAL
)
lite_download_and_uncompress
(
${
LITE_MODEL_DIR
}
${
LITE_URL
}
"lite_fc_model.tar.gz"
)
add_dependencies
(
test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz
)
paddle/fluid/lite/core/mir/fc_fuse_pass.cc
0 → 100644
浏览文件 @
234fb8f4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
FcFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
fusion
::
FcFuser
fuser
;
fuser
(
graph
.
get
());
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_fc_fuse_pass
,
paddle
::
lite
::
mir
::
FcFusePass
);
paddle/fluid/lite/core/mir/fc_fuse_pass.h
0 → 100644
浏览文件 @
234fb8f4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
FcFusePass
:
public
ProgramPass
{
public:
void
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
override
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc
0 → 100644
浏览文件 @
234fb8f4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
DEFINE_string
(
model_dir
,
""
,
""
);
DEFINE_string
(
optimized_model
,
""
,
""
);
namespace
paddle
{
namespace
lite
{
namespace
mir
{
TEST
(
fc_fuse_pass
,
fuse_test
)
{
lite
::
ExecutorLite
predictor
;
#ifndef LITE_WITH_CUDA
std
::
vector
<
Place
>
valid_places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)},
Place
{
TARGET
(
kX86
),
PRECISION
(
kFloat
)}});
#else
std
::
vector
<
Place
>
valid_places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)},
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
)},
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kAny
),
DATALAYOUT
(
kNCHW
)},
Place
{
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kNCHW
)},
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)},
Place
{
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)},
});
#endif
predictor
.
Build
(
FLAGS_model_dir
,
Place
{
TARGET
(
kX86
),
PRECISION
(
kFloat
)},
// origin cuda
valid_places
);
auto
*
input_tensor
=
predictor
.
GetInput
(
0
);
input_tensor
->
Resize
(
DDim
(
std
::
vector
<
DDim
::
value_type
>
({
100
,
100
})));
auto
*
data
=
input_tensor
->
mutable_data
<
float
>
();
for
(
int
i
=
0
;
i
<
100
*
100
;
i
++
)
{
data
[
i
]
=
i
;
}
predictor
.
Run
();
auto
*
out
=
predictor
.
GetOutput
(
0
);
LOG
(
INFO
)
<<
out
<<
" memory size "
<<
out
->
data_size
();
LOG
(
INFO
)
<<
"out "
<<
out
->
data
<
float
>
()[
0
];
LOG
(
INFO
)
<<
"out "
<<
out
->
data
<
float
>
()[
1
];
LOG
(
INFO
)
<<
"dims "
<<
out
->
dims
();
EXPECT_NEAR
(
out
->
data
<
float
>
()[
0
],
38.120617
f
,
1e-5
);
EXPECT_NEAR
(
out
->
data
<
float
>
()[
1
],
10.109812
f
,
1e-5
);
CHECK_EQ
(
out
->
dims
()[
0
],
100
);
CHECK_EQ
(
out
->
dims
()[
1
],
500
);
}
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
TEST
(
fc_fuse_pass
,
save_model_test
)
{
lite
::
ExecutorLite
predictor
;
std
::
vector
<
Place
>
valid_places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)},
Place
{
TARGET
(
kX86
),
PRECISION
(
kFloat
)}});
predictor
.
Build
(
FLAGS_model_dir
,
Place
{
TARGET
(
kX86
),
PRECISION
(
kFloat
)},
valid_places
);
LOG
(
INFO
)
<<
"Save optimized model to "
<<
FLAGS_optimized_model
;
predictor
.
SaveModel
(
FLAGS_optimized_model
);
}
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
}
// namespace mir
}
// namespace lite
}
// namespace paddle
USE_LITE_OP
(
mul
);
USE_LITE_OP
(
elementwise_add
);
USE_LITE_OP
(
elementwise_sub
);
USE_LITE_OP
(
fc
);
USE_LITE_OP
(
feed
);
USE_LITE_OP
(
fetch
);
USE_LITE_OP
(
io_copy
);
USE_LITE_OP
(
softmax
);
USE_LITE_OP
(
scale
);
USE_LITE_KERNEL
(
feed
,
kHost
,
kAny
,
kAny
,
def
);
USE_LITE_KERNEL
(
fetch
,
kHost
,
kAny
,
kAny
,
def
);
#ifdef LITE_WITH_X86
USE_LITE_KERNEL
(
mul
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
fc
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
elementwise_sub
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
elementwise_add
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
softmax
,
kX86
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
scale
,
kX86
,
kFloat
,
kNCHW
,
def
);
#endif
#ifdef LITE_WITH_CUDA
USE_LITE_KERNEL
(
mul
,
kCUDA
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
kAny
,
host_to_device
);
USE_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
kAny
,
device_to_host
);
#endif
paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
0 → 100644
浏览文件 @
234fb8f4
cc_library
(
mir_fusers
SRCS fc_fuser.cc
DEPS pattern_matcher_high_api
)
paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
0 → 100644
浏览文件 @
234fb8f4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
#include <memory>
#include <vector>
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
void
FcFuser
::
BuildPattern
()
{
// create nodes.
auto
*
x
=
VarNode
(
"x"
)
->
assert_is_op_input
(
"mul"
,
"X"
);
auto
*
W
=
VarNode
(
"W"
)
->
assert_is_op_input
(
"mul"
,
"Y"
);
auto
*
b
=
VarNode
(
"b"
);
auto
*
mul
=
OpNode
(
"mul"
,
"mul"
);
auto
*
mul_out
=
VarNode
(
"mul_out"
);
auto
*
add
=
OpNode
(
"add"
,
"elementwise_add"
);
auto
*
Out
=
VarNode
(
"Out"
);
// create topology.
std
::
vector
<
PMNode
*>
mul_inputs
{
W
,
x
};
std
::
vector
<
PMNode
*>
add_inputs
{
mul_out
,
b
};
mul_inputs
>>
*
mul
>>
*
mul_out
;
add_inputs
>>
*
add
>>
*
Out
;
// Some op specialities.
mul_out
->
AsIntermediate
();
mul
->
AsIntermediate
();
add
->
AsIntermediate
();
}
void
FcFuser
::
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
{
auto
op_desc
=
GenOpDesc
(
matched
);
auto
fc_op
=
LiteOpRegistry
::
Global
().
Create
(
"fc"
);
auto
mul
=
matched
.
at
(
"mul"
)
->
stmt
()
->
op
;
auto
*
scope
=
mul
->
scope
();
auto
&
valid_places
=
mul
->
valid_places
();
fc_op
->
Attach
(
op_desc
,
scope
);
auto
*
new_op_node
=
graph
->
GraphCreateInstructNode
(
fc_op
,
valid_places
);
IR_NODE_LINK_TO
(
matched
.
at
(
"W"
),
new_op_node
);
IR_NODE_LINK_TO
(
matched
.
at
(
"x"
),
new_op_node
);
IR_NODE_LINK_TO
(
matched
.
at
(
"b"
),
new_op_node
);
IR_NODE_LINK_TO
(
new_op_node
,
matched
.
at
(
"Out"
));
}
cpp
::
OpDesc
FcFuser
::
GenOpDesc
(
const
key2nodes_t
&
matched
)
{
cpp
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"fc"
);
op_desc
.
SetInput
(
"Input"
,
{
matched
.
at
(
"x"
)
->
arg
()
->
name
});
op_desc
.
SetInput
(
"W"
,
{
matched
.
at
(
"W"
)
->
arg
()
->
name
});
op_desc
.
SetInput
(
"Bias"
,
{
matched
.
at
(
"b"
)
->
arg
()
->
name
});
op_desc
.
SetOutput
(
"Out"
,
{
matched
.
at
(
"Out"
)
->
arg
()
->
name
});
op_desc
.
SetAttr
(
"in_num_col_dims"
,
matched
.
at
(
"mul"
)
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"x_num_col_dims"
));
return
op_desc
;
}
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/fusion/fc_fuser.h
0 → 100644
浏览文件 @
234fb8f4
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
namespace
fusion
{
class
FcFuser
:
public
FuseBase
{
public:
void
BuildPattern
()
override
;
void
InsertNewNode
(
SSAGraph
*
graph
,
const
key2nodes_t
&
matched
)
override
;
private:
cpp
::
OpDesc
GenOpDesc
(
const
key2nodes_t
&
matched
)
override
;
};
}
// namespace fusion
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/passes.h
浏览文件 @
234fb8f4
...
@@ -22,6 +22,7 @@ namespace mir {} // namespace mir
...
@@ -22,6 +22,7 @@ namespace mir {} // namespace mir
}
// namespace paddle
}
// namespace paddle
USE_MIR_PASS
(
demo
);
USE_MIR_PASS
(
demo
);
USE_MIR_PASS
(
lite_fc_fuse_pass
);
USE_MIR_PASS
(
static_kernel_pick_pass
);
USE_MIR_PASS
(
static_kernel_pick_pass
);
USE_MIR_PASS
(
variable_place_inference_pass
);
USE_MIR_PASS
(
variable_place_inference_pass
);
USE_MIR_PASS
(
type_target_transform_pass
);
USE_MIR_PASS
(
type_target_transform_pass
);
...
...
paddle/fluid/lite/core/mir/pattern_matcher.cc
浏览文件 @
234fb8f4
...
@@ -45,10 +45,11 @@ PMNode &PMNode::operator>>(std::vector<PMNode *> &nodes) {
...
@@ -45,10 +45,11 @@ PMNode &PMNode::operator>>(std::vector<PMNode *> &nodes) {
return
*
this
;
return
*
this
;
}
}
void
operator
>>
(
std
::
vector
<
PMNode
*>
&
others
,
PMNode
&
me
)
{
PMNode
&
operator
>>
(
std
::
vector
<
PMNode
*>
&
others
,
PMNode
&
me
)
{
for
(
auto
*
o
:
others
)
{
for
(
auto
*
o
:
others
)
{
*
o
>>
me
;
*
o
>>
me
;
}
}
return
me
;
}
}
PMNode
*
PMPattern
::
NewNode
(
const
std
::
string
&
name
)
{
PMNode
*
PMPattern
::
NewNode
(
const
std
::
string
&
name
)
{
...
@@ -422,6 +423,46 @@ PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
...
@@ -422,6 +423,46 @@ PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
return
this
;
return
this
;
}
}
PMNode
*
PMNode
::
assert_is_op_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
)
{
assert_is_var
();
assert_is_op_nth_input
(
op_type
,
argument
,
0
);
return
this
;
}
PMNode
*
PMNode
::
assert_is_op_nth_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
)
{
assert_is_var
();
assert_is_op_input
(
op_type
);
asserts_
.
emplace_back
([
=
](
const
Node
*
x
)
{
for
(
auto
*
op
:
x
->
outlinks
)
{
if
(
op
->
IsStmt
()
&&
op
->
stmt
()
->
op_info
()
->
Type
()
==
op_type
&&
IsNthInput
(
*
x
,
*
op
,
argument
,
nth
))
return
true
;
}
return
false
;
});
return
this
;
}
bool
IsNthInput
(
const
Node
&
var
,
const
Node
&
op
,
const
std
::
string
&
argument
,
int
nth
)
{
CHECK
(
var
.
IsArg
());
CHECK
(
op
.
IsStmt
());
if
(
!
HasInput
(
op
,
argument
)
||
static_cast
<
int
>
(
op
.
stmt
()
->
op_info
()
->
Input
(
argument
).
size
())
<=
nth
)
return
false
;
return
var
.
arg
()
->
name
==
op
.
stmt
()
->
op_info
()
->
Input
(
argument
)[
nth
];
}
bool
HasInput
(
const
Node
&
op
,
const
std
::
string
&
argument
)
{
CHECK
(
op
.
IsStmt
());
auto
const
&
names
=
op
.
stmt
()
->
op_info
()
->
input_argnames
();
if
(
std
::
find
(
names
.
begin
(),
names
.
end
(),
argument
)
==
names
.
end
())
return
false
;
return
true
;
}
void
GraphSafeRemoveNodes
(
SSAGraph
*
graph
,
void
GraphSafeRemoveNodes
(
SSAGraph
*
graph
,
const
std
::
unordered_set
<
const
Node
*>
&
nodes
)
{
const
std
::
unordered_set
<
const
Node
*>
&
nodes
)
{
for
(
auto
*
node
:
nodes
)
{
for
(
auto
*
node
:
nodes
)
{
...
...
paddle/fluid/lite/core/mir/pattern_matcher.h
浏览文件 @
234fb8f4
...
@@ -62,7 +62,7 @@ struct PMNode {
...
@@ -62,7 +62,7 @@ struct PMNode {
PMNode
&
operator
>>
(
PMNode
&
right
);
PMNode
&
operator
>>
(
PMNode
&
right
);
// Link many nodes to this node.
// Link many nodes to this node.
friend
void
operator
>>
(
std
::
vector
<
PMNode
*>&
others
,
PMNode
&
me
);
friend
PMNode
&
operator
>>
(
std
::
vector
<
PMNode
*>&
others
,
PMNode
&
me
);
// Link this to many other nodes.
// Link this to many other nodes.
PMNode
&
operator
>>
(
std
::
vector
<
PMNode
*>&
nodes
);
PMNode
&
operator
>>
(
std
::
vector
<
PMNode
*>&
nodes
);
...
@@ -127,6 +127,10 @@ struct PMNode {
...
@@ -127,6 +127,10 @@ struct PMNode {
PMNode
*
assert_is_persistable_var
();
PMNode
*
assert_is_persistable_var
();
PMNode
*
assert_is_op_output
(
const
std
::
string
&
op_type
);
PMNode
*
assert_is_op_output
(
const
std
::
string
&
op_type
);
PMNode
*
assert_is_op_input
(
const
std
::
string
&
op_type
);
PMNode
*
assert_is_op_input
(
const
std
::
string
&
op_type
);
PMNode
*
assert_is_op_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
);
PMNode
*
assert_is_op_nth_input
(
const
std
::
string
&
op_type
,
const
std
::
string
&
argument
,
int
nth
);
template
<
typename
T
>
template
<
typename
T
>
PMNode
*
assert_op_attr
(
const
std
::
string
&
attr_name
,
const
T
&
attr
)
{
PMNode
*
assert_op_attr
(
const
std
::
string
&
attr_name
,
const
T
&
attr
)
{
...
@@ -297,6 +301,13 @@ class PatternMatcher {
...
@@ -297,6 +301,13 @@ class PatternMatcher {
std
::
unordered_map
<
const
PMNode
*
,
std
::
unordered_set
<
Node
*>>
pmnodes2nodes_
;
std
::
unordered_map
<
const
PMNode
*
,
std
::
unordered_set
<
Node
*>>
pmnodes2nodes_
;
};
};
// Check whether a var node is a op node's nth input.
bool
IsNthInput
(
const
Node
&
var
,
const
Node
&
op
,
const
std
::
string
&
argument
,
int
nth
);
// Check whether the op node has input of given name.
bool
HasInput
(
const
Node
&
op
,
const
std
::
string
&
argument
);
// Graph safely remove some nodes, will automatically clean up the edges.
// Graph safely remove some nodes, will automatically clean up the edges.
void
GraphSafeRemoveNodes
(
SSAGraph
*
graph
,
void
GraphSafeRemoveNodes
(
SSAGraph
*
graph
,
const
std
::
unordered_set
<
const
Node
*>&
nodes
);
const
std
::
unordered_set
<
const
Node
*>&
nodes
);
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
浏览文件 @
234fb8f4
...
@@ -64,7 +64,6 @@ class FuseBase {
...
@@ -64,7 +64,6 @@ class FuseBase {
// Delete nodes that are marked as Intermediate
// Delete nodes that are marked as Intermediate
void
DeleteInterNodes
(
SSAGraph
*
graph
);
void
DeleteInterNodes
(
SSAGraph
*
graph
);
private:
PMNode
*
GetOrCreateNode
(
const
std
::
string
&
key
);
PMNode
*
GetOrCreateNode
(
const
std
::
string
&
key
);
protected:
protected:
...
...
paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
浏览文件 @
234fb8f4
...
@@ -29,8 +29,8 @@ class FcFuser : public FuseBase {
...
@@ -29,8 +29,8 @@ class FcFuser : public FuseBase {
public:
public:
void
BuildPattern
()
override
{
void
BuildPattern
()
override
{
// create nodes.
// create nodes.
auto
*
x
=
VarNode
(
"x"
);
auto
*
x
=
VarNode
(
"x"
)
->
assert_is_op_input
(
"mul"
,
"X"
)
;
auto
*
W
=
VarNode
(
"W"
);
auto
*
W
=
VarNode
(
"W"
)
->
assert_is_op_input
(
"mul"
,
"Y"
)
;
auto
*
b
=
VarNode
(
"b"
);
auto
*
b
=
VarNode
(
"b"
);
auto
*
mul
=
OpNode
(
"mul"
,
"mul"
);
auto
*
mul
=
OpNode
(
"mul"
,
"mul"
);
auto
*
mul_out
=
VarNode
(
"mul_out"
);
auto
*
mul_out
=
VarNode
(
"mul_out"
);
...
@@ -38,12 +38,10 @@ class FcFuser : public FuseBase {
...
@@ -38,12 +38,10 @@ class FcFuser : public FuseBase {
auto
*
Out
=
VarNode
(
"Out"
);
auto
*
Out
=
VarNode
(
"Out"
);
// create topology.
// create topology.
// std::vector<PMNode*>({W, x}) >> *mul >> *mul_out;
std
::
vector
<
PMNode
*>
mul_inputs
{
W
,
x
};
// std::vector<PMNode*>({mul_out, b}) >> *add >> *Out;
std
::
vector
<
PMNode
*>
add_inputs
{
mul_out
,
b
};
*
W
>>
*
mul
;
mul_inputs
>>
*
mul
>>
*
mul_out
;
*
x
>>
*
mul
>>
*
mul_out
;
add_inputs
>>
*
add
>>
*
Out
;
*
b
>>
*
add
;
*
mul_out
>>
*
add
>>
*
Out
;
// Some op specialities.
// Some op specialities.
mul_out
->
AsIntermediate
();
mul_out
->
AsIntermediate
();
...
@@ -91,14 +89,12 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
...
@@ -91,14 +89,12 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
main_block
->
Var
(
"mul_out"
);
main_block
->
Var
(
"mul_out"
);
main_block
->
Var
(
"w"
);
main_block
->
Var
(
"w"
);
main_block
->
Var
(
"out"
);
main_block
->
Var
(
"out"
);
main_block
->
Var
(
"out1"
);
scope
->
Var
(
"w"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"w"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"b"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"b"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"mul_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"mul_out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"w"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"w"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"out"
)
->
GetMutable
<
lite
::
Tensor
>
();
scope
->
Var
(
"out1"
)
->
GetMutable
<
lite
::
Tensor
>
();
mul
->
SetInput
(
"X"
,
{
"x"
});
mul
->
SetInput
(
"X"
,
{
"x"
});
mul
->
SetInput
(
"Y"
,
{
"w"
});
mul
->
SetInput
(
"Y"
,
{
"w"
});
...
@@ -122,18 +118,18 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
...
@@ -122,18 +118,18 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
return
graph
;
return
graph
;
}
}
TEST
(
pattern_matcher
2
,
graph_test
)
{
TEST
(
pattern_matcher
_high_api
,
graph_test
)
{
framework
::
ProgramDesc
program_desc
;
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
auto
graph
=
BuildGraph
(
&
program_desc
,
scope
,
places
);
ASSERT_EQ
(
graph
->
nodes
().
size
(),
ASSERT_EQ
(
graph
->
nodes
().
size
(),
8
UL
/*real nodes*/
+
2UL
/*feed op + fetch op*/
);
7
UL
/*real nodes*/
+
2UL
/*feed op + fetch op*/
);
Visualize
(
graph
.
get
());
Visualize
(
graph
.
get
());
}
}
TEST
(
pattern_matcher
2
,
test
)
{
TEST
(
pattern_matcher
_high_api
,
fuse_
test
)
{
framework
::
ProgramDesc
program_desc
;
framework
::
ProgramDesc
program_desc
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
scope
=
std
::
make_shared
<
Scope
>
();
...
@@ -143,6 +139,7 @@ TEST(pattern_matcher2, test) {
...
@@ -143,6 +139,7 @@ TEST(pattern_matcher2, test) {
fuser
(
graph
.
get
());
fuser
(
graph
.
get
());
ASSERT_EQ
(
graph
->
nodes
().
size
(),
ASSERT_EQ
(
graph
->
nodes
().
size
(),
num_nodes
-
3UL
/*nodes removed */
+
1UL
/* fused fc node*/
);
num_nodes
-
3UL
/*nodes removed */
+
1UL
/* fused fc node*/
);
Visualize
(
graph
.
get
());
}
}
}
// namespace mir
}
// namespace mir
...
...
paddle/fluid/lite/core/optimizer.h
浏览文件 @
234fb8f4
...
@@ -49,6 +49,7 @@ class Optimizer {
...
@@ -49,6 +49,7 @@ class Optimizer {
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
if
(
passes
.
empty
())
{
if
(
passes
.
empty
())
{
RunPasses
(
std
::
vector
<
std
::
string
>
{{
RunPasses
(
std
::
vector
<
std
::
string
>
{{
"lite_fc_fuse_pass"
,
//
"static_kernel_pick_pass"
,
//
"static_kernel_pick_pass"
,
//
"variable_place_inference_pass"
,
//
"variable_place_inference_pass"
,
//
"argument_type_display_pass"
,
//
"argument_type_display_pass"
,
//
...
...
paddle/fluid/lite/core/profile/basic_profiler.h
浏览文件 @
234fb8f4
...
@@ -152,8 +152,8 @@ class BasicProfiler {
...
@@ -152,8 +152,8 @@ class BasicProfiler {
}
}
record_t
*
mutable_record
(
int
id
)
{
record_t
*
mutable_record
(
int
id
)
{
CHECK_LT
(
id
,
records_
.
size
());
CHECK_GE
(
id
,
0
);
CHECK_GE
(
id
,
0
);
CHECK_LT
(
static_cast
<
size_t
>
(
id
),
records_
.
size
());
return
&
records_
[
id
];
return
&
records_
[
id
];
}
}
...
...
paddle/fluid/lite/kernels/host/CMakeLists.txt
浏览文件 @
234fb8f4
...
@@ -10,6 +10,4 @@ set(host_kernels
...
@@ -10,6 +10,4 @@ set(host_kernels
feed_compute_host
feed_compute_host
fetch_compute_host
fetch_compute_host
reshape_compute_host
reshape_compute_host
)
CACHE INTERNAL
"host kernels"
)
set
(
host_kernels
"
${
host_kernels
}
"
CACHE GLOBAL
"host kernels"
)
paddle/fluid/lite/kernels/x86/CMakeLists.txt
浏览文件 @
234fb8f4
...
@@ -32,6 +32,4 @@ set(x86_kernels
...
@@ -32,6 +32,4 @@ set(x86_kernels
concat_compute_x86
concat_compute_x86
conv_compute_x86
conv_compute_x86
pool_compute_x86
pool_compute_x86
)
CACHE INTERNAL
"x86 kernels"
)
set
(
x86_kernels
"
${
x86_kernels
}
"
CACHE INTERNAL
"x86 kernels"
)
paddle/fluid/lite/kernels/x86/fc_compute.cc
浏览文件 @
234fb8f4
...
@@ -27,8 +27,8 @@ namespace kernels {
...
@@ -27,8 +27,8 @@ namespace kernels {
namespace
x86
{
namespace
x86
{
template
<
typename
T
>
template
<
typename
T
>
void
fc_compute_eigen
(
const
T
*
x
,
int
x_
w
,
int
x_h
,
//
void
fc_compute_eigen
(
const
T
*
x
,
int
x_
h
,
int
x_w
,
//
const
T
*
w
,
int
w_
w
,
int
w_h
,
//
const
T
*
w
,
int
w_
h
,
int
w_w
,
//
const
T
*
b
,
//
const
T
*
b
,
//
T
*
out
)
{
T
*
out
)
{
using
matrix_t
=
using
matrix_t
=
...
@@ -36,38 +36,31 @@ void fc_compute_eigen(const T* x, int x_w, int x_h, //
...
@@ -36,38 +36,31 @@ void fc_compute_eigen(const T* x, int x_w, int x_h, //
Eigen
::
Map
<
const
matrix_t
>
X
(
x
,
x_h
,
x_w
);
Eigen
::
Map
<
const
matrix_t
>
X
(
x
,
x_h
,
x_w
);
Eigen
::
Map
<
const
matrix_t
>
W
(
w
,
w_h
,
w_w
);
Eigen
::
Map
<
const
matrix_t
>
W
(
w
,
w_h
,
w_w
);
Eigen
::
Map
<
matrix_t
>
Out
(
out
,
x_h
,
w_
h
);
Eigen
::
Map
<
matrix_t
>
Out
(
out
,
x_h
,
w_
w
);
Out
=
X
*
W
.
transpose
()
;
Out
=
X
*
W
;
if
(
b
)
{
if
(
b
)
{
Eigen
::
Map
<
const
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
1
>>
B
(
b
,
w_
h
);
Eigen
::
Map
<
const
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
1
>>
B
(
b
,
w_
w
);
Out
=
Out
.
array
().
rowwise
()
+
B
.
transpose
().
array
();
Out
=
Out
.
array
().
rowwise
()
+
B
.
transpose
().
array
();
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
__attribute__
((
optimize
(
"unroll-loops"
)))
//
void
fc_compute_naive
(
const
T
*
x
,
int
x_h
,
int
x_w
,
//
T
dot
(
const
T
*
x
,
const
T
*
y
,
int
dim
)
{
const
T
*
w
,
int
w_h
,
int
w_w
,
//
T
out
{};
for
(
int
i
=
0
;
i
<
dim
;
i
++
)
{
out
+=
x
[
i
]
*
y
[
i
];
}
return
out
;
}
template
<
typename
T
>
void
fc_compute_naive
(
const
T
*
x
,
int
x_w
,
int
x_h
,
//
const
T
*
w
,
int
w_w
,
int
w_h
,
//
const
T
*
b
,
//
const
T
*
b
,
//
T
*
out
)
{
T
*
out
)
{
CHECK_EQ
(
x_w
,
w_
w
);
CHECK_EQ
(
x_w
,
w_
h
);
// out shape: (x_h, w_w)
// out shape: (x_h, w_w)
memset
(
out
,
0
,
x_h
*
w_h
*
sizeof
(
T
));
memset
(
out
,
0
,
x_h
*
w_w
*
sizeof
(
T
));
for
(
int
i
=
0
;
i
<
x_h
;
i
++
)
{
for
(
int
r
=
0
;
r
<
x_h
;
r
++
)
{
for
(
int
j
=
0
;
j
<
w_w
;
j
++
)
{
for
(
int
c
=
0
;
c
<
w_h
;
c
++
)
{
T
tmp
=
static_cast
<
T
>
(
0
);
out
[
r
*
w_h
+
c
]
=
dot
(
&
x
[
r
*
x_w
],
&
w
[
c
*
w_w
],
w_w
)
+
b
[
c
];
for
(
int
k
=
0
;
k
<
x_w
;
k
++
)
{
tmp
+=
x
[
i
*
x_w
+
k
]
*
w
[
k
*
w_w
+
j
];
}
out
[
i
*
w_w
+
j
]
=
tmp
+
b
[
j
];
}
}
}
}
}
}
...
@@ -89,8 +82,8 @@ class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
...
@@ -89,8 +82,8 @@ class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
.
Slice
(
param
.
in_num_col_dims
,
param
.
input
->
dims
().
size
())
.
Slice
(
param
.
in_num_col_dims
,
param
.
input
->
dims
().
size
())
.
production
(),
.
production
(),
param
.
w
->
data
<
T
>
(),
// w
param
.
w
->
data
<
T
>
(),
// w
param
.
w
->
dims
()[
1
],
// w_w
param
.
w
->
dims
()[
0
],
// w_h
param
.
w
->
dims
()[
0
],
// w_h
param
.
w
->
dims
()[
1
],
// w_w
param
.
bias
->
data
<
T
>
(),
// b
param
.
bias
->
data
<
T
>
(),
// b
param
.
output
->
mutable_data
<
T
>
());
param
.
output
->
mutable_data
<
T
>
());
}
}
...
...
paddle/fluid/lite/operators/CMakeLists.txt
浏览文件 @
234fb8f4
...
@@ -38,7 +38,7 @@ set(ops_lite
...
@@ -38,7 +38,7 @@ set(ops_lite
concat_op_lite
concat_op_lite
conv_op_lite
conv_op_lite
pool_op_lite
pool_op_lite
PARENT_SCOPE
)
CACHE INTERNAL
"ops lite"
)
lite_cc_test
(
test_fc_op_lite SRCS fc_op_test.cc
lite_cc_test
(
test_fc_op_lite SRCS fc_op_test.cc
DEPS fc_op_lite memory_lite
DEPS fc_op_lite memory_lite
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录