Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
12db9f3c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
12db9f3c
编写于
4月 22, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
make the predictor works from a faked model
上级
610ce3ae
变更
26
隐藏空白更改
内联
并排
Showing
26 changed file
with
315 addition
and
62 deletion
+315
-62
paddle/fluid/lite/api/CMakeLists.txt
paddle/fluid/lite/api/CMakeLists.txt
+1
-1
paddle/fluid/lite/api/cxx_api.h
paddle/fluid/lite/api/cxx_api.h
+23
-2
paddle/fluid/lite/api/cxx_api_test.cc
paddle/fluid/lite/api/cxx_api_test.cc
+13
-23
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+1
-1
paddle/fluid/lite/core/mir/io_complement_pass.cc
paddle/fluid/lite/core/mir/io_complement_pass.cc
+2
-2
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+23
-4
paddle/fluid/lite/core/op_executor_test.cc
paddle/fluid/lite/core/op_executor_test.cc
+7
-11
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+3
-0
paddle/fluid/lite/core/optimizer.cc
paddle/fluid/lite/core/optimizer.cc
+16
-0
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+11
-1
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+21
-7
paddle/fluid/lite/core/type_system.cc
paddle/fluid/lite/core/type_system.cc
+39
-0
paddle/fluid/lite/core/type_system.h
paddle/fluid/lite/core/type_system.h
+9
-0
paddle/fluid/lite/kernels/host/CMakeLists.txt
paddle/fluid/lite/kernels/host/CMakeLists.txt
+3
-1
paddle/fluid/lite/kernels/host/feed_compute.cc
paddle/fluid/lite/kernels/host/feed_compute.cc
+4
-5
paddle/fluid/lite/kernels/host/fetch_compute.cc
paddle/fluid/lite/kernels/host/fetch_compute.cc
+50
-0
paddle/fluid/lite/kernels/host/scale_compute.cc
paddle/fluid/lite/kernels/host/scale_compute.cc
+4
-0
paddle/fluid/lite/operators/CMakeLists.txt
paddle/fluid/lite/operators/CMakeLists.txt
+2
-0
paddle/fluid/lite/operators/fc_op.h
paddle/fluid/lite/operators/fc_op.h
+2
-0
paddle/fluid/lite/operators/feed_op.cc
paddle/fluid/lite/operators/feed_op.cc
+4
-0
paddle/fluid/lite/operators/fetch_op.cc
paddle/fluid/lite/operators/fetch_op.cc
+61
-0
paddle/fluid/lite/operators/io_copy_op.h
paddle/fluid/lite/operators/io_copy_op.h
+2
-0
paddle/fluid/lite/operators/mul_op.h
paddle/fluid/lite/operators/mul_op.h
+1
-0
paddle/fluid/lite/operators/op_params.h
paddle/fluid/lite/operators/op_params.h
+10
-4
paddle/fluid/lite/operators/relu_op.h
paddle/fluid/lite/operators/relu_op.h
+1
-0
paddle/fluid/lite/operators/scale_op.cc
paddle/fluid/lite/operators/scale_op.cc
+2
-0
未找到文件。
paddle/fluid/lite/api/CMakeLists.txt
浏览文件 @
12db9f3c
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite
)
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite
optimizer_lite
)
cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite
)
cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite
)
paddle/fluid/lite/api/cxx_api.h
浏览文件 @
12db9f3c
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -30,7 +31,6 @@ class Predictor {
...
@@ -30,7 +31,6 @@ class Predictor {
void
Build
(
const
std
::
string
&
model_path
,
void
Build
(
const
std
::
string
&
model_path
,
const
std
::
vector
<
Place
>&
valid_places
)
{
const
std
::
vector
<
Place
>&
valid_places
)
{
CHECK
(
!
scope_
.
get
())
<<
"duplicate build found"
;
framework
::
proto
::
ProgramDesc
prog
;
framework
::
proto
::
ProgramDesc
prog
;
LoadModel
(
model_path
,
scope_
.
get
(),
&
prog
);
LoadModel
(
model_path
,
scope_
.
get
(),
&
prog
);
framework
::
ProgramDesc
prog_desc
(
prog
);
framework
::
ProgramDesc
prog_desc
(
prog
);
...
@@ -38,10 +38,31 @@ class Predictor {
...
@@ -38,10 +38,31 @@ class Predictor {
Program
program
(
prog_desc
,
scope_
,
valid_places
);
Program
program
(
prog_desc
,
scope_
,
valid_places
);
Optimizer
optimizer
;
Optimizer
optimizer
;
optimizer
.
Run
(
std
::
move
(
program
),
valid_places
);
core
::
KernelPickFactor
factor
;
factor
.
ConsiderTarget
();
optimizer
.
Run
(
std
::
move
(
program
),
valid_places
,
factor
);
program_
=
optimizer
.
GenRuntimeProgram
();
program_
=
optimizer
.
GenRuntimeProgram
();
}
}
// Get offset-th col of feed.
Tensor
*
GetInput
(
size_t
offset
)
{
auto
*
_feed_list
=
program_
->
exec_scope
()
->
FindVar
(
"feed"
);
CHECK
(
_feed_list
)
<<
"no feed variable in exec_scope"
;
auto
*
feed_list
=
_feed_list
->
GetMutable
<
std
::
vector
<
Tensor
>>
();
if
(
offset
>=
feed_list
->
size
())
{
feed_list
->
resize
(
offset
+
1
);
}
return
&
feed_list
->
at
(
offset
);
}
const
Tensor
*
GetOutput
(
size_t
offset
)
{
auto
*
_fetch_list
=
program_
->
exec_scope
()
->
FindVar
(
"fetch"
);
CHECK
(
_fetch_list
)
<<
"no fatch variable in exec_scope"
;
auto
fetch_list
=
_fetch_list
->
Get
<
std
::
vector
<
Tensor
>>
();
CHECK_LT
(
offset
,
fetch_list
.
size
())
<<
"offset "
<<
offset
<<
" overflow"
;
return
&
fetch_list
.
at
(
offset
);
}
void
Run
()
{
program_
->
Run
();
}
void
Run
()
{
program_
->
Run
();
}
private:
private:
...
...
paddle/fluid/lite/api/cxx_api_test.cc
浏览文件 @
12db9f3c
...
@@ -14,36 +14,22 @@
...
@@ -14,36 +14,22 @@
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_executor.h"
#include "paddle/fluid/lite/core/op_executor.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
TEST
(
CXXApi
,
raw
)
{
Scope
scope
;
framework
::
proto
::
ProgramDesc
prog
;
LoadModel
(
"/home/chunwei/project2/models/model2"
,
&
scope
,
&
prog
);
framework
::
ProgramDesc
prog_desc
(
prog
);
lite
::
Executor
executor
(
&
scope
,
{
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
auto
x
=
scope
.
Var
(
"a"
)
->
GetMutable
<
Tensor
>
();
x
->
Resize
({
100
,
100
});
x
->
mutable_data
<
float
>
();
executor
.
PrepareWorkspace
(
prog_desc
);
executor
.
Build
(
prog_desc
);
executor
.
Run
();
}
TEST
(
CXXApi
,
test
)
{
TEST
(
CXXApi
,
test
)
{
lite
::
Predictor
predictor
;
lite
::
Predictor
predictor
;
predictor
.
Build
(
"/home/chunwei/project2/models/model2"
,
predictor
.
Build
(
"/home/chunwei/project2/models/model2"
,
{
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
{
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
auto
*
x
=
predictor
.
GetInputTensor
(
"a"
);
x
->
Resize
({
100
,
200
});
auto
*
input_tensor
=
predictor
.
GetInput
(
0
);
x
->
mutable_data
<
float
>
();
input_tensor
->
Resize
({
100
,
100
});
input_tensor
->
mutable_data
<
float
>
();
predictor
.
Run
();
}
}
}
// namespace lite
}
// namespace lite
...
@@ -52,6 +38,10 @@ TEST(CXXApi, test) {
...
@@ -52,6 +38,10 @@ TEST(CXXApi, test) {
USE_LITE_OP
(
mul
);
USE_LITE_OP
(
mul
);
USE_LITE_OP
(
fc
);
USE_LITE_OP
(
fc
);
USE_LITE_OP
(
scale
);
USE_LITE_OP
(
scale
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
);
USE_LITE_OP
(
feed
);
USE_LITE_KERNEL
(
mul
,
kHost
,
kFloat
);
USE_LITE_OP
(
fetch
);
USE_LITE_KERNEL
(
scale
,
kHost
,
kFloat
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
mul
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
feed
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
fetch
,
kHost
,
kFloat
,
def
);
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
12db9f3c
cc_library
(
mir_node SRCS node.cc
)
cc_library
(
mir_node SRCS node.cc
)
cc_library
(
mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node
)
cc_library
(
mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node
)
cc_library
(
mir_pass SRCS pass.cc DEPS mir_ssa_graph
)
cc_library
(
mir_pass SRCS pass.cc DEPS mir_ssa_graph
)
cc_library
(
mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph
)
cc_library
(
mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph
mir_passes
)
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
cc_library
(
mir_passes
cc_library
(
mir_passes
SRCS static_kernel_pick_pass.cc
SRCS static_kernel_pick_pass.cc
...
...
paddle/fluid/lite/core/mir/io_complement_pass.cc
浏览文件 @
12db9f3c
...
@@ -36,8 +36,8 @@ void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
...
@@ -36,8 +36,8 @@ void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
inst
.
place
,
inst
.
op_type
,
tmp
);
inst
.
place
,
inst
.
op_type
,
tmp
);
CHECK
(
type
)
<<
"no param type found for "
<<
inst
.
op_type
<<
":"
<<
name
CHECK
(
type
)
<<
"no param type found for "
<<
inst
.
op_type
<<
":"
<<
name
<<
" "
<<
inst
.
place
;
<<
" "
<<
inst
.
place
;
if
(
type
->
tensor_place
!=
in
st
.
place
)
{
if
(
type
->
tensor_place
!=
in
->
AsArgument
()
.
place
)
{
LOG
(
INFO
)
<<
"found IO unmatched tensor
"
;
LOG
(
INFO
)
<<
"found IO unmatched tensor
: "
<<
in
->
AsArgument
().
name
;
}
}
}
}
}
}
...
...
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
12db9f3c
...
@@ -36,7 +36,7 @@ class SSAGraph : GraphBase {
...
@@ -36,7 +36,7 @@ class SSAGraph : GraphBase {
// @param program: the op program
// @param program: the op program
// @param valid_places: the valid places user set for the system.
// @param valid_places: the valid places user set for the system.
void
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
void
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
// create
inputs
// create
temporary nodes.
for
(
const
auto
&
name
:
program
.
tmp_vars
)
{
for
(
const
auto
&
name
:
program
.
tmp_vars
)
{
node_storage_
.
emplace_back
();
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_node
=
node_storage_
.
back
();
...
@@ -45,20 +45,33 @@ class SSAGraph : GraphBase {
...
@@ -45,20 +45,33 @@ class SSAGraph : GraphBase {
arguments_
[
name
]
=
&
new_node
;
arguments_
[
name
]
=
&
new_node
;
}
}
// create weight nodes.
for
(
const
auto
&
name
:
program
.
weights
)
{
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
arg
=
new_node
.
AsArgument
();
arg
.
name
=
name
;
arguments_
[
name
]
=
&
new_node
;
}
for
(
auto
&
op
:
program
.
ops
)
{
for
(
auto
&
op
:
program
.
ops
)
{
node_storage_
.
emplace_back
();
node_storage_
.
emplace_back
();
// TODO(Superjomn) remove one valid_places here.
// TODO(Superjomn) remove one valid_places here.
op
->
SetValidPlaces
(
valid_places
);
op
->
SetValidPlaces
(
valid_places
);
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_node
=
node_storage_
.
back
();
node_storage_
.
back
().
AsInstruct
(
auto
kernels
=
op
->
CreateKernels
(
valid_places
);
op
->
op_type_
,
op
->
CreateKernels
(
valid_places
),
op
,
op
->
op_info
());
for
(
auto
&
kernel
:
kernels
)
{
op
->
AttachKernel
(
kernel
.
get
());
}
node_storage_
.
back
().
AsInstruct
(
op
->
op_type_
,
std
::
move
(
kernels
),
op
,
op
->
op_info
());
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
// collect inputs and outputs
// collect inputs and outputs
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
input_names
())
{
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
input_names
())
{
auto
*
arg
=
arguments_
.
a
t
(
name
);
auto
*
arg
=
Argumen
t
(
name
);
new_node
.
inlinks
.
push_back
(
arg
);
new_node
.
inlinks
.
push_back
(
arg
);
arg
->
outlinks
.
push_back
(
&
new_node
);
arg
->
outlinks
.
push_back
(
&
new_node
);
}
}
...
@@ -79,6 +92,12 @@ class SSAGraph : GraphBase {
...
@@ -79,6 +92,12 @@ class SSAGraph : GraphBase {
MarkArgumentWeights
(
program
);
MarkArgumentWeights
(
program
);
}
}
mir
::
Node
*
Argument
(
const
std
::
string
&
name
)
{
auto
it
=
arguments_
.
find
(
name
);
CHECK
(
it
!=
arguments_
.
end
())
<<
"no argument called "
<<
name
;
return
it
->
second
;
}
std
::
vector
<
mir
::
Node
*>
InstructTopologicalOrder
();
std
::
vector
<
mir
::
Node
*>
InstructTopologicalOrder
();
// The inputs of the graph.
// The inputs of the graph.
...
...
paddle/fluid/lite/core/op_executor_test.cc
浏览文件 @
12db9f3c
...
@@ -20,12 +20,9 @@ namespace paddle {
...
@@ -20,12 +20,9 @@ namespace paddle {
namespace
lite
{
namespace
lite
{
TEST
(
executor
,
test
)
{
TEST
(
executor
,
test
)
{
std
::
vector
<
OpLite
::
Place
>
valid_places
{
std
::
vector
<
Place
>
valid_places
{
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
OpLite
::
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
Scope
scope
;
auto
scope
=
std
::
make_shared
<
lite
::
Scope
>
();
Executor
executor
(
&
scope
,
valid_places
);
framework
::
ProgramDesc
program
;
framework
::
ProgramDesc
program
;
program
.
MutableBlock
(
0
)
->
Var
(
"x"
);
program
.
MutableBlock
(
0
)
->
Var
(
"x"
);
...
@@ -42,19 +39,18 @@ TEST(executor, test) {
...
@@ -42,19 +39,18 @@ TEST(executor, test) {
op_desc
.
SetAttr
(
"in_num_col_dims"
,
static_cast
<
int
>
(
1
));
op_desc
.
SetAttr
(
"in_num_col_dims"
,
static_cast
<
int
>
(
1
));
program
.
Flush
();
program
.
Flush
();
auto
*
w
=
scope
.
Var
(
"w"
)
->
GetMutable
<
Tensor
>
();
auto
*
w
=
scope
->
Var
(
"w"
)
->
GetMutable
<
Tensor
>
();
w
->
Resize
({
20
,
20
});
w
->
Resize
({
20
,
20
});
auto
*
x
=
scope
.
Var
(
"x"
)
->
GetMutable
<
Tensor
>
();
auto
*
x
=
scope
->
Var
(
"x"
)
->
GetMutable
<
Tensor
>
();
x
->
Resize
({
1
,
10
,
20
});
x
->
Resize
({
1
,
10
,
20
});
auto
*
bias
=
scope
.
Var
(
"bias"
)
->
GetMutable
<
Tensor
>
();
auto
*
bias
=
scope
->
Var
(
"bias"
)
->
GetMutable
<
Tensor
>
();
bias
->
Resize
({
1
,
20
});
bias
->
Resize
({
1
,
20
});
bias
->
mutable_data
<
float
>
();
bias
->
mutable_data
<
float
>
();
w
->
mutable_data
<
float
>
();
w
->
mutable_data
<
float
>
();
x
->
mutable_data
<
float
>
();
x
->
mutable_data
<
float
>
();
executor
.
PrepareWorkspace
(
program
);
lite
::
Executor
executor
(
program
,
scope
,
valid_places
);
executor
.
Build
(
program
);
executor
.
Run
();
executor
.
Run
();
}
}
...
@@ -62,4 +58,4 @@ TEST(executor, test) {
...
@@ -62,4 +58,4 @@ TEST(executor, test) {
}
// namespace paddle
}
// namespace paddle
USE_LITE_OP
(
fc
);
USE_LITE_OP
(
fc
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
def
);
paddle/fluid/lite/core/op_lite.h
浏览文件 @
12db9f3c
...
@@ -101,6 +101,9 @@ class OpLite : public Registry {
...
@@ -101,6 +101,9 @@ class OpLite : public Registry {
virtual
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
virtual
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
lite
::
Scope
*
scope
)
=
0
;
// Assign op param to kernel.
virtual
void
AttachKernel
(
KernelBase
*
kernel
)
=
0
;
// Specify the kernel to run by default. This will specify the value of
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
// `kernel_place_`.
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
{
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
{
...
...
paddle/fluid/lite/core/optimizer.cc
浏览文件 @
12db9f3c
...
@@ -11,3 +11,19 @@
...
@@ -11,3 +11,19 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
namespace
paddle
{
namespace
lite
{
void
Optimizer
::
SpecifyKernelPickTactic
(
core
::
KernelPickFactor
factor
)
{
auto
*
pass
=
mir
::
PassManager
::
Global
().
LookUp
<
mir
::
StaticKernelPickPass
>
(
"static_kernel_pick_pass"
);
CHECK
(
pass
);
*
pass
->
mutable_kernel_pick_factors
()
=
factor
;
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/optimizer.h
浏览文件 @
12db9f3c
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
...
@@ -30,11 +31,14 @@ namespace lite {
...
@@ -30,11 +31,14 @@ namespace lite {
class
Optimizer
{
class
Optimizer
{
public:
public:
void
Run
(
Program
&&
program
,
const
std
::
vector
<
Place
>&
valid_places
,
void
Run
(
Program
&&
program
,
const
std
::
vector
<
Place
>&
valid_places
,
core
::
KernelPickFactor
kernel_pick_factor
,
const
std
::
vector
<
std
::
string
>&
passes
=
{})
{
const
std
::
vector
<
std
::
string
>&
passes
=
{})
{
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
->
Build
(
program
,
valid_places
);
graph_
->
Build
(
program
,
valid_places
);
SpecifyKernelPickTactic
(
kernel_pick_factor
);
RunPasses
();
RunPasses
();
exec_scope_
=
program
.
exec_scope
;
}
}
// Generate a new program based on the mir graph.
// Generate a new program based on the mir graph.
...
@@ -42,7 +46,10 @@ class Optimizer {
...
@@ -42,7 +46,10 @@ class Optimizer {
std
::
unique_ptr
<
Program
>
res
;
std
::
unique_ptr
<
Program
>
res
;
auto
pass
=
mir
::
PassManager
::
Global
().
LookUp
<
mir
::
GenerateProgramPass
>
(
auto
pass
=
mir
::
PassManager
::
Global
().
LookUp
<
mir
::
GenerateProgramPass
>
(
"generate_program_pass"
);
"generate_program_pass"
);
return
pass
->
GenProgram
();
auto
program
=
pass
->
GenProgram
();
CHECK
(
exec_scope_
);
program
->
set_exec_scope
(
exec_scope_
);
return
program
;
}
}
// Generate C++ code which combines the inference program, model and weights.
// Generate C++ code which combines the inference program, model and weights.
...
@@ -54,6 +61,8 @@ class Optimizer {
...
@@ -54,6 +61,8 @@ class Optimizer {
}
}
protected:
protected:
void
SpecifyKernelPickTactic
(
core
::
KernelPickFactor
factor
);
// Run the default passes registered in the PassManager.
// Run the default passes registered in the PassManager.
void
RunPasses
()
{
mir
::
PassManager
::
Global
().
Run
(
graph_
);
}
void
RunPasses
()
{
mir
::
PassManager
::
Global
().
Run
(
graph_
);
}
...
@@ -62,6 +71,7 @@ class Optimizer {
...
@@ -62,6 +71,7 @@ class Optimizer {
private:
private:
std
::
unique_ptr
<
mir
::
SSAGraph
>
graph_
;
std
::
unique_ptr
<
mir
::
SSAGraph
>
graph_
;
lite
::
Scope
*
exec_scope_
{};
};
};
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/program.h
浏览文件 @
12db9f3c
...
@@ -35,23 +35,22 @@ struct Program {
...
@@ -35,23 +35,22 @@ struct Program {
std
::
list
<
std
::
shared_ptr
<
OpLite
>>
ops
;
std
::
list
<
std
::
shared_ptr
<
OpLite
>>
ops
;
// the scope to run the kernels, NOTE not the root scope.
// the scope to run the kernels, NOTE not the root scope.
std
::
shared_ptr
<
lite
::
Scope
>
scope
;
std
::
shared_ptr
<
lite
::
Scope
>
scope
;
std
::
vector
<
Place
>
valid_places
;
// Runtime scope.
// Runtime scope.
lite
::
Scope
*
exec_scope
{};
lite
::
Scope
*
exec_scope
{};
const
framework
::
ProgramDesc
desc
;
explicit
Program
(
const
std
::
shared_ptr
<
Scope
>&
root
)
{
scope
=
root
;
}
explicit
Program
(
const
std
::
shared_ptr
<
Scope
>&
root
)
{
scope
=
root
;
}
Program
(
const
framework
::
ProgramDesc
&
desc
,
Program
(
const
framework
::
ProgramDesc
&
desc
,
const
std
::
shared_ptr
<
Scope
>&
root
,
const
std
::
shared_ptr
<
Scope
>&
root
,
const
std
::
vector
<
Place
>&
valid_places
)
{
const
std
::
vector
<
Place
>&
valid_places
)
scope
=
root
;
:
scope
(
root
),
valid_places
(
valid_places
),
desc
(
desc
)
{
PrepareWorkspace
(
desc
);
PrepareWorkspace
(
desc
);
Build
(
desc
,
valid_places
);
Build
(
desc
,
valid_places
);
}
}
std
::
unique_ptr
<
Program
>
Clone
()
const
{
std
::
unique_ptr
<
Program
>
Clone
()
const
{
std
::
unique_ptr
<
Program
>
res
(
new
Program
(
scope
));
std
::
unique_ptr
<
Program
>
res
(
new
Program
(
desc
,
scope
,
valid_places
));
res
->
tmp_vars
=
tmp_vars
;
res
->
weights
=
weights
;
res
->
ops
=
ops
;
return
res
;
return
res
;
}
}
...
@@ -64,7 +63,7 @@ struct Program {
...
@@ -64,7 +63,7 @@ struct Program {
// Create operators.
// Create operators.
for
(
auto
*
op_desc
:
program
.
Block
(
0
).
AllOps
())
{
for
(
auto
*
op_desc
:
program
.
Block
(
0
).
AllOps
())
{
auto
op_type
=
op_desc
->
Type
();
auto
op_type
=
op_desc
->
Type
();
if
(
op_type
==
"feed"
||
op_type
==
"fetch"
)
continue
;
//
if (op_type == "feed" || op_type == "fetch") continue;
LOG
(
INFO
)
<<
"create Op ["
<<
op_type
<<
"]"
;
LOG
(
INFO
)
<<
"create Op ["
<<
op_type
<<
"]"
;
ops
.
emplace_back
(
LiteOpRegistry
::
Global
().
Create
(
op_type
));
ops
.
emplace_back
(
LiteOpRegistry
::
Global
().
Create
(
op_type
));
// pick initial kernel
// pick initial kernel
...
@@ -77,11 +76,22 @@ struct Program {
...
@@ -77,11 +76,22 @@ struct Program {
void
PrepareWorkspace
(
const
framework
::
ProgramDesc
&
program
)
{
void
PrepareWorkspace
(
const
framework
::
ProgramDesc
&
program
)
{
CHECK
(
!
exec_scope
)
<<
"Duplicate PrepareWorkspace found"
;
CHECK
(
!
exec_scope
)
<<
"Duplicate PrepareWorkspace found"
;
exec_scope
=
&
scope
->
NewScope
();
exec_scope
=
&
scope
->
NewScope
();
// Create Feed and Fetch var.
scope
->
Var
(
"feed"
)
->
GetMutable
<
std
::
vector
<
Tensor
>>
();
scope
->
Var
(
"fetch"
)
->
GetMutable
<
std
::
vector
<
Tensor
>>
();
tmp_vars
.
push_back
(
"feed"
);
tmp_vars
.
push_back
(
"fetch"
);
for
(
auto
var_desc
:
program
.
Block
(
0
).
AllVars
())
{
for
(
auto
var_desc
:
program
.
Block
(
0
).
AllVars
())
{
if
(
!
var_desc
->
Persistable
())
{
if
(
!
var_desc
->
Persistable
())
{
LOG
(
INFO
)
<<
"get tmp var "
<<
var_desc
->
Name
();
tmp_vars
.
push_back
(
var_desc
->
Name
());
auto
*
var
=
exec_scope
->
Var
(
var_desc
->
Name
());
auto
*
var
=
exec_scope
->
Var
(
var_desc
->
Name
());
LOG
(
INFO
)
<<
"create tmp var "
<<
var_desc
->
Name
()
<<
" "
<<
var
;
LOG
(
INFO
)
<<
"create tmp var "
<<
var_desc
->
Name
()
<<
" "
<<
var
;
}
else
{
if
(
var_desc
->
Name
()
==
"feed"
||
var_desc
->
Name
()
==
"fetch"
)
continue
;
LOG
(
INFO
)
<<
"get weight var "
<<
var_desc
->
Name
();
weights
.
push_back
(
var_desc
->
Name
());
}
}
}
}
}
}
...
@@ -118,11 +128,15 @@ class RuntimeProgram {
...
@@ -118,11 +128,15 @@ class RuntimeProgram {
}
}
}
}
void
set_exec_scope
(
lite
::
Scope
*
x
)
{
exec_scope_
=
x
;
}
lite
::
Scope
*
exec_scope
()
{
return
exec_scope_
;
}
size_t
num_instructions
()
const
{
return
instructions_
.
size
();
}
size_t
num_instructions
()
const
{
return
instructions_
.
size
();
}
private:
private:
RuntimeProgram
(
const
RuntimeProgram
&
)
=
delete
;
RuntimeProgram
(
const
RuntimeProgram
&
)
=
delete
;
std
::
vector
<
Instruction
>
instructions_
;
std
::
vector
<
Instruction
>
instructions_
;
lite
::
Scope
*
exec_scope_
{};
};
};
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/type_system.cc
浏览文件 @
12db9f3c
...
@@ -48,6 +48,45 @@ const Type* Type::Get<UnsupportedTy>(TargetType target) {
...
@@ -48,6 +48,45 @@ const Type* Type::Get<UnsupportedTy>(TargetType target) {
DataLayoutType
::
kNCHW
>
();
DataLayoutType
::
kNCHW
>
();
}
}
template
<
TargetType
Target
>
TensorListAnyTy
*
GetTensorListAnyTy
()
{
static
TensorListAnyTy
x
(
Target
);
return
&
x
;
}
template
<
TargetType
Target
>
TensorAnyTy
*
GetTensorAnyTy
()
{
static
TensorAnyTy
x
(
Target
);
return
&
x
;
}
template
<
>
const
Type
*
Type
::
Get
<
TensorListAnyTy
>
(
TargetType
target
)
{
switch
(
target
)
{
case
TargetType
::
kHost
:
return
GetTensorListAnyTy
<
TARGET
(
kHost
)
>
();
case
TargetType
::
kCUDA
:
return
GetTensorListAnyTy
<
TARGET
(
kCUDA
)
>
();
case
TargetType
::
kX86
:
return
GetTensorListAnyTy
<
TARGET
(
kX86
)
>
();
default:
LOG
(
FATAL
)
<<
"unsupported type"
;
}
}
template
<
>
const
Type
*
Type
::
Get
<
TensorAnyTy
>
(
TargetType
target
)
{
switch
(
target
)
{
case
TargetType
::
kHost
:
return
GetTensorAnyTy
<
TARGET
(
kHost
)
>
();
case
TargetType
::
kCUDA
:
return
GetTensorAnyTy
<
TARGET
(
kCUDA
)
>
();
case
TargetType
::
kX86
:
return
GetTensorAnyTy
<
TARGET
(
kX86
)
>
();
default:
LOG
(
FATAL
)
<<
"unsupported type"
;
}
}
template
<
>
template
<
>
const
Type
*
Type
::
Get
<
TensorFp32NCHWTy
>
(
TargetType
target
)
{
const
Type
*
Type
::
Get
<
TensorFp32NCHWTy
>
(
TargetType
target
)
{
switch
(
target
)
{
switch
(
target
)
{
...
...
paddle/fluid/lite/core/type_system.h
浏览文件 @
12db9f3c
...
@@ -60,6 +60,8 @@ class DataTypeBase {
...
@@ -60,6 +60,8 @@ class DataTypeBase {
// Tensor_Any represents a Tensor with any place, data, layout. It is used
// Tensor_Any represents a Tensor with any place, data, layout. It is used
// in some IO kernels those doesn't care the data.
// in some IO kernels those doesn't care the data.
Tensor_Any
,
Tensor_Any
,
// Used by feed or fetch op.
TensorList_Any
,
NumTypes
,
// Must remains as last defined ID.
NumTypes
,
// Must remains as last defined ID.
};
};
...
@@ -146,6 +148,13 @@ class TensorAnyTy : public Type {
...
@@ -146,6 +148,13 @@ class TensorAnyTy : public Type {
:
Type
(
ID
::
Tensor_Any
,
"TensorAny"
,
true
,
target
,
PRECISION
(
kAny
),
:
Type
(
ID
::
Tensor_Any
,
"TensorAny"
,
true
,
target
,
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
))
{}
DATALAYOUT
(
kAny
))
{}
};
};
// A list of tensor, and no assumption on the data layout or data type.
class
TensorListAnyTy
:
public
Type
{
public:
TensorListAnyTy
(
TargetType
target
)
:
Type
(
ID
::
TensorList_Any
,
"TensorList_Any"
,
false
,
target
,
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
))
{}
};
class
TensorFp32NCHWTy
:
public
Type
{
class
TensorFp32NCHWTy
:
public
Type
{
public:
public:
TensorFp32NCHWTy
(
TargetType
target
)
TensorFp32NCHWTy
(
TargetType
target
)
...
...
paddle/fluid/lite/kernels/host/CMakeLists.txt
浏览文件 @
12db9f3c
...
@@ -3,13 +3,15 @@ cc_library(relu_compute_host SRCS relu_compute.cc DEPS ${lite_kernel_deps})
...
@@ -3,13 +3,15 @@ cc_library(relu_compute_host SRCS relu_compute.cc DEPS ${lite_kernel_deps})
cc_library
(
mul_compute_host SRCS mul_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
mul_compute_host SRCS mul_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
scale_compute_host SRCS scale_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
scale_compute_host SRCS scale_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
feed_compute_host SRCS feed_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
feed_compute_host SRCS feed_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
fetch_compute_host SRCS fetch_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
host_kernels DEPS
cc_library
(
host_kernels DEPS
feed_compute_host
fetch_compute_host
fc_compute_host
fc_compute_host
relu_compute_host
relu_compute_host
mul_compute_host
mul_compute_host
scale_compute_host
scale_compute_host
feed_compute_host
DEPS
${
lite_kernel_deps
}
DEPS
${
lite_kernel_deps
}
)
)
...
...
paddle/fluid/lite/kernels/host/feed_compute.cc
浏览文件 @
12db9f3c
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include <Eigen/Core>
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
#include "paddle/fluid/lite/core/type_system.h"
...
@@ -26,9 +25,9 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
...
@@ -26,9 +25,9 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
using
param_t
=
operators
::
FeedParam
;
using
param_t
=
operators
::
FeedParam
;
void
Run
()
override
{
void
Run
()
override
{
auto
&
the
param
=
Param
<
operators
::
FeedParam
>
();
auto
&
param
=
Param
<
operators
::
FeedParam
>
();
const
Tensor
&
feed_item
=
theparam
.
feed_list
->
at
(
the
param
.
col
);
const
Tensor
&
feed_item
=
param
.
feed_list
->
at
(
param
.
col
);
the
param
.
out
->
CopyDataFrom
(
feed_item
);
param
.
out
->
CopyDataFrom
(
feed_item
);
}
}
};
};
...
@@ -39,7 +38,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
...
@@ -39,7 +38,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL
(
feed
,
kHost
,
kFloat
,
REGISTER_LITE_KERNEL
(
feed
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
FeedCompute
,
def
)
paddle
::
lite
::
kernels
::
host
::
FeedCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
Tensor
Fp32NCHW
Ty
>
(
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
Tensor
Any
Ty
>
(
TARGET
(
kHost
))})
TARGET
(
kHost
))})
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
TARGET
(
kHost
))})
...
...
paddle/fluid/lite/kernels/host/fetch_compute.cc
0 → 100644
浏览文件 @
12db9f3c
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
host
{
class
FetchCompute
:
public
OpKernel
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
{
public:
using
param_t
=
operators
::
FeedParam
;
void
Run
()
override
{
auto
&
param
=
Param
<
operators
::
FetchParam
>
();
auto
*
fetch_list
=
param
.
fetch_list
;
if
(
fetch_list
->
size
()
<=
static_cast
<
size_t
>
(
param
.
col
))
{
fetch_list
->
resize
(
param
.
col
+
1
);
}
auto
&
dst
=
fetch_list
->
at
(
param
.
col
);
dst
.
CopyDataFrom
(
*
param
.
input
);
}
};
}
// namespace host
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
fetch
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
FetchCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorAnyTy
>
(
TARGET
(
kHost
))})
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorListAnyTy
>
(
TARGET
(
kHost
))})
.
Finalize
();
paddle/fluid/lite/kernels/host/scale_compute.cc
浏览文件 @
12db9f3c
...
@@ -52,4 +52,8 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
...
@@ -52,4 +52,8 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
REGISTER_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
ScaleCompute
,
def
)
paddle
::
lite
::
kernels
::
host
::
ScaleCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
.
Finalize
();
.
Finalize
();
paddle/fluid/lite/operators/CMakeLists.txt
浏览文件 @
12db9f3c
...
@@ -3,6 +3,7 @@ cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
...
@@ -3,6 +3,7 @@ cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
cc_library
(
mul_op_lite SRCS mul_op.cc DEPS op_lite
)
cc_library
(
mul_op_lite SRCS mul_op.cc DEPS op_lite
)
cc_library
(
scale_op_lite SRCS scale_op.cc DEPS op_lite
)
cc_library
(
scale_op_lite SRCS scale_op.cc DEPS op_lite
)
cc_library
(
feed_op_lite SRCS feed_op.cc DEPS op_lite
)
cc_library
(
feed_op_lite SRCS feed_op.cc DEPS op_lite
)
cc_library
(
fetch_op_lite SRCS fetch_op.cc DEPS op_lite
)
cc_library
(
io_copy_op_lite SRCS io_copy_op.cc DEPS op_lite
)
cc_library
(
io_copy_op_lite SRCS io_copy_op.cc DEPS op_lite
)
cc_library
(
op_params_lite SRCS op_params.cc DEPS tensor_lite
)
cc_library
(
op_params_lite SRCS op_params.cc DEPS tensor_lite
)
...
@@ -12,6 +13,7 @@ cc_library(ops_lite DEPS
...
@@ -12,6 +13,7 @@ cc_library(ops_lite DEPS
mul_op_lite
mul_op_lite
scale_op_lite
scale_op_lite
feed_op_lite
feed_op_lite
fetch_op_lite
io_copy_op_lite
io_copy_op_lite
)
)
...
...
paddle/fluid/lite/operators/fc_op.h
浏览文件 @
12db9f3c
...
@@ -67,6 +67,8 @@ class FcOpLite : public OpLite {
...
@@ -67,6 +67,8 @@ class FcOpLite : public OpLite {
return
true
;
return
true
;
}
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"fc"
;
}
std
::
string
DebugString
()
const
override
{
return
"fc"
;
}
private:
private:
...
...
paddle/fluid/lite/operators/feed_op.cc
浏览文件 @
12db9f3c
...
@@ -31,6 +31,9 @@ class FeedOp : public OpLite {
...
@@ -31,6 +31,9 @@ class FeedOp : public OpLite {
bool
InferShape
()
const
override
{
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
protected:
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
protected:
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
lite
::
Scope
*
scope
)
override
{
...
@@ -48,6 +51,7 @@ class FeedOp : public OpLite {
...
@@ -48,6 +51,7 @@ class FeedOp : public OpLite {
// NOTE need boost here
// NOTE need boost here
// TODO(Superjomn) drop the need of framework::op_desc
// TODO(Superjomn) drop the need of framework::op_desc
param_
.
col
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"col"
));
param_
.
col
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"col"
));
kernel_
->
SetParam
(
param_
);
return
true
;
return
true
;
}
}
...
...
paddle/fluid/lite/operators/fetch_op.cc
0 → 100644
浏览文件 @
12db9f3c
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
class
FetchOp
:
public
OpLite
{
public:
explicit
FetchOp
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
{
CHECK_OR_FALSE
(
param_
.
input
);
CHECK_OR_FALSE
(
param_
.
fetch_list
);
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
_x
=
opdesc
.
Input
(
"X"
).
front
();
auto
*
x
=
scope
->
FindVar
(
_x
);
CHECK
(
x
);
param_
.
input
=
&
x
->
Get
<
Tensor
>
();
auto
_out
=
opdesc
.
Output
(
"Out"
).
front
();
auto
*
out
=
scope
->
FindVar
(
_out
);
param_
.
fetch_list
=
out
->
GetMutable
<
std
::
vector
<
lite
::
Tensor
>>
();
param_
.
col
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"col"
));
return
true
;
}
std
::
string
DebugString
()
const
override
{
return
"fetch"
;
}
private:
mutable
FetchParam
param_
;
};
}
// namespace operators
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_OP
(
fetch
,
paddle
::
lite
::
operators
::
FetchOp
);
paddle/fluid/lite/operators/io_copy_op.h
浏览文件 @
12db9f3c
...
@@ -28,6 +28,8 @@ class IoCopyOp : public OpLite {
...
@@ -28,6 +28,8 @@ class IoCopyOp : public OpLite {
bool
Run
()
override
;
bool
Run
()
override
;
std
::
string
DebugString
()
const
override
;
std
::
string
DebugString
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
protected:
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
paddle/fluid/lite/operators/mul_op.h
浏览文件 @
12db9f3c
...
@@ -36,6 +36,7 @@ class MulOpLite : public OpLite {
...
@@ -36,6 +36,7 @@ class MulOpLite : public OpLite {
bool
InferShape
()
const
override
;
bool
InferShape
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
lite
::
Scope
*
scope
)
override
{
...
...
paddle/fluid/lite/operators/op_params.h
浏览文件 @
12db9f3c
...
@@ -25,8 +25,14 @@ namespace lite {
...
@@ -25,8 +25,14 @@ namespace lite {
namespace
operators
{
namespace
operators
{
struct
FeedParam
{
struct
FeedParam
{
const
std
::
vector
<
Tensor
>*
feed_list
;
const
std
::
vector
<
Tensor
>*
feed_list
{};
Tensor
*
out
;
Tensor
*
out
{};
int
col
;
};
struct
FetchParam
{
const
Tensor
*
input
{};
std
::
vector
<
Tensor
>*
fetch_list
{};
int
col
;
int
col
;
};
};
...
@@ -69,8 +75,8 @@ struct IoCopyParam {
...
@@ -69,8 +75,8 @@ struct IoCopyParam {
Tensor
*
y
{};
Tensor
*
y
{};
};
};
using
param_t
=
using
param_t
=
variant
<
FeedParam
,
FetchParam
,
FcParam
,
ReluParam
,
MulParam
,
variant
<
FeedParam
,
FcParam
,
ReluParam
,
MulParam
,
ScaleParam
,
IoCopyParam
>
;
ScaleParam
,
IoCopyParam
>
;
}
// namespace operators
}
// namespace operators
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/operators/relu_op.h
浏览文件 @
12db9f3c
...
@@ -34,6 +34,7 @@ class ReluOp : public OpLite {
...
@@ -34,6 +34,7 @@ class ReluOp : public OpLite {
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"tanh"
;
}
std
::
string
DebugString
()
const
override
{
return
"tanh"
;
}
private:
private:
...
...
paddle/fluid/lite/operators/scale_op.cc
浏览文件 @
12db9f3c
...
@@ -43,6 +43,8 @@ class ScaleOp : public OpLite {
...
@@ -43,6 +43,8 @@ class ScaleOp : public OpLite {
return
true
;
return
true
;
}
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
lite
::
Scope
*
scope
)
override
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录