Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
12db9f3c
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
12db9f3c
编写于
4月 22, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
make the predictor works from a faked model
上级
610ce3ae
变更
26
显示空白变更内容
内联
并排
Showing
26 changed file
with
315 addition
and
62 deletion
+315
-62
paddle/fluid/lite/api/CMakeLists.txt
paddle/fluid/lite/api/CMakeLists.txt
+1
-1
paddle/fluid/lite/api/cxx_api.h
paddle/fluid/lite/api/cxx_api.h
+23
-2
paddle/fluid/lite/api/cxx_api_test.cc
paddle/fluid/lite/api/cxx_api_test.cc
+13
-23
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+1
-1
paddle/fluid/lite/core/mir/io_complement_pass.cc
paddle/fluid/lite/core/mir/io_complement_pass.cc
+2
-2
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+23
-4
paddle/fluid/lite/core/op_executor_test.cc
paddle/fluid/lite/core/op_executor_test.cc
+7
-11
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+3
-0
paddle/fluid/lite/core/optimizer.cc
paddle/fluid/lite/core/optimizer.cc
+16
-0
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+11
-1
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+21
-7
paddle/fluid/lite/core/type_system.cc
paddle/fluid/lite/core/type_system.cc
+39
-0
paddle/fluid/lite/core/type_system.h
paddle/fluid/lite/core/type_system.h
+9
-0
paddle/fluid/lite/kernels/host/CMakeLists.txt
paddle/fluid/lite/kernels/host/CMakeLists.txt
+3
-1
paddle/fluid/lite/kernels/host/feed_compute.cc
paddle/fluid/lite/kernels/host/feed_compute.cc
+4
-5
paddle/fluid/lite/kernels/host/fetch_compute.cc
paddle/fluid/lite/kernels/host/fetch_compute.cc
+50
-0
paddle/fluid/lite/kernels/host/scale_compute.cc
paddle/fluid/lite/kernels/host/scale_compute.cc
+4
-0
paddle/fluid/lite/operators/CMakeLists.txt
paddle/fluid/lite/operators/CMakeLists.txt
+2
-0
paddle/fluid/lite/operators/fc_op.h
paddle/fluid/lite/operators/fc_op.h
+2
-0
paddle/fluid/lite/operators/feed_op.cc
paddle/fluid/lite/operators/feed_op.cc
+4
-0
paddle/fluid/lite/operators/fetch_op.cc
paddle/fluid/lite/operators/fetch_op.cc
+61
-0
paddle/fluid/lite/operators/io_copy_op.h
paddle/fluid/lite/operators/io_copy_op.h
+2
-0
paddle/fluid/lite/operators/mul_op.h
paddle/fluid/lite/operators/mul_op.h
+1
-0
paddle/fluid/lite/operators/op_params.h
paddle/fluid/lite/operators/op_params.h
+10
-4
paddle/fluid/lite/operators/relu_op.h
paddle/fluid/lite/operators/relu_op.h
+1
-0
paddle/fluid/lite/operators/scale_op.cc
paddle/fluid/lite/operators/scale_op.cc
+2
-0
未找到文件。
paddle/fluid/lite/api/CMakeLists.txt
浏览文件 @
12db9f3c
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite
)
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite
optimizer_lite
)
cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite
)
cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite
)
paddle/fluid/lite/api/cxx_api.h
浏览文件 @
12db9f3c
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -30,7 +31,6 @@ class Predictor {
...
@@ -30,7 +31,6 @@ class Predictor {
void
Build
(
const
std
::
string
&
model_path
,
void
Build
(
const
std
::
string
&
model_path
,
const
std
::
vector
<
Place
>&
valid_places
)
{
const
std
::
vector
<
Place
>&
valid_places
)
{
CHECK
(
!
scope_
.
get
())
<<
"duplicate build found"
;
framework
::
proto
::
ProgramDesc
prog
;
framework
::
proto
::
ProgramDesc
prog
;
LoadModel
(
model_path
,
scope_
.
get
(),
&
prog
);
LoadModel
(
model_path
,
scope_
.
get
(),
&
prog
);
framework
::
ProgramDesc
prog_desc
(
prog
);
framework
::
ProgramDesc
prog_desc
(
prog
);
...
@@ -38,10 +38,31 @@ class Predictor {
...
@@ -38,10 +38,31 @@ class Predictor {
Program
program
(
prog_desc
,
scope_
,
valid_places
);
Program
program
(
prog_desc
,
scope_
,
valid_places
);
Optimizer
optimizer
;
Optimizer
optimizer
;
optimizer
.
Run
(
std
::
move
(
program
),
valid_places
);
core
::
KernelPickFactor
factor
;
factor
.
ConsiderTarget
();
optimizer
.
Run
(
std
::
move
(
program
),
valid_places
,
factor
);
program_
=
optimizer
.
GenRuntimeProgram
();
program_
=
optimizer
.
GenRuntimeProgram
();
}
}
// Get offset-th col of feed.
Tensor
*
GetInput
(
size_t
offset
)
{
auto
*
_feed_list
=
program_
->
exec_scope
()
->
FindVar
(
"feed"
);
CHECK
(
_feed_list
)
<<
"no feed variable in exec_scope"
;
auto
*
feed_list
=
_feed_list
->
GetMutable
<
std
::
vector
<
Tensor
>>
();
if
(
offset
>=
feed_list
->
size
())
{
feed_list
->
resize
(
offset
+
1
);
}
return
&
feed_list
->
at
(
offset
);
}
const
Tensor
*
GetOutput
(
size_t
offset
)
{
auto
*
_fetch_list
=
program_
->
exec_scope
()
->
FindVar
(
"fetch"
);
CHECK
(
_fetch_list
)
<<
"no fatch variable in exec_scope"
;
auto
fetch_list
=
_fetch_list
->
Get
<
std
::
vector
<
Tensor
>>
();
CHECK_LT
(
offset
,
fetch_list
.
size
())
<<
"offset "
<<
offset
<<
" overflow"
;
return
&
fetch_list
.
at
(
offset
);
}
void
Run
()
{
program_
->
Run
();
}
void
Run
()
{
program_
->
Run
();
}
private:
private:
...
...
paddle/fluid/lite/api/cxx_api_test.cc
浏览文件 @
12db9f3c
...
@@ -14,36 +14,22 @@
...
@@ -14,36 +14,22 @@
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_executor.h"
#include "paddle/fluid/lite/core/op_executor.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
TEST
(
CXXApi
,
raw
)
{
Scope
scope
;
framework
::
proto
::
ProgramDesc
prog
;
LoadModel
(
"/home/chunwei/project2/models/model2"
,
&
scope
,
&
prog
);
framework
::
ProgramDesc
prog_desc
(
prog
);
lite
::
Executor
executor
(
&
scope
,
{
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
auto
x
=
scope
.
Var
(
"a"
)
->
GetMutable
<
Tensor
>
();
x
->
Resize
({
100
,
100
});
x
->
mutable_data
<
float
>
();
executor
.
PrepareWorkspace
(
prog_desc
);
executor
.
Build
(
prog_desc
);
executor
.
Run
();
}
TEST
(
CXXApi
,
test
)
{
TEST
(
CXXApi
,
test
)
{
lite
::
Predictor
predictor
;
lite
::
Predictor
predictor
;
predictor
.
Build
(
"/home/chunwei/project2/models/model2"
,
predictor
.
Build
(
"/home/chunwei/project2/models/model2"
,
{
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
{
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
auto
*
x
=
predictor
.
GetInputTensor
(
"a"
);
x
->
Resize
({
100
,
200
});
auto
*
input_tensor
=
predictor
.
GetInput
(
0
);
x
->
mutable_data
<
float
>
();
input_tensor
->
Resize
({
100
,
100
});
input_tensor
->
mutable_data
<
float
>
();
predictor
.
Run
();
}
}
}
// namespace lite
}
// namespace lite
...
@@ -52,6 +38,10 @@ TEST(CXXApi, test) {
...
@@ -52,6 +38,10 @@ TEST(CXXApi, test) {
USE_LITE_OP
(
mul
);
USE_LITE_OP
(
mul
);
USE_LITE_OP
(
fc
);
USE_LITE_OP
(
fc
);
USE_LITE_OP
(
scale
);
USE_LITE_OP
(
scale
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
);
USE_LITE_OP
(
feed
);
USE_LITE_KERNEL
(
mul
,
kHost
,
kFloat
);
USE_LITE_OP
(
fetch
);
USE_LITE_KERNEL
(
scale
,
kHost
,
kFloat
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
mul
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
feed
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
fetch
,
kHost
,
kFloat
,
def
);
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
12db9f3c
cc_library
(
mir_node SRCS node.cc
)
cc_library
(
mir_node SRCS node.cc
)
cc_library
(
mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node
)
cc_library
(
mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node
)
cc_library
(
mir_pass SRCS pass.cc DEPS mir_ssa_graph
)
cc_library
(
mir_pass SRCS pass.cc DEPS mir_ssa_graph
)
cc_library
(
mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph
)
cc_library
(
mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph
mir_passes
)
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
cc_library
(
mir_passes
cc_library
(
mir_passes
SRCS static_kernel_pick_pass.cc
SRCS static_kernel_pick_pass.cc
...
...
paddle/fluid/lite/core/mir/io_complement_pass.cc
浏览文件 @
12db9f3c
...
@@ -36,8 +36,8 @@ void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
...
@@ -36,8 +36,8 @@ void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
inst
.
place
,
inst
.
op_type
,
tmp
);
inst
.
place
,
inst
.
op_type
,
tmp
);
CHECK
(
type
)
<<
"no param type found for "
<<
inst
.
op_type
<<
":"
<<
name
CHECK
(
type
)
<<
"no param type found for "
<<
inst
.
op_type
<<
":"
<<
name
<<
" "
<<
inst
.
place
;
<<
" "
<<
inst
.
place
;
if
(
type
->
tensor_place
!=
in
st
.
place
)
{
if
(
type
->
tensor_place
!=
in
->
AsArgument
()
.
place
)
{
LOG
(
INFO
)
<<
"found IO unmatched tensor
"
;
LOG
(
INFO
)
<<
"found IO unmatched tensor
: "
<<
in
->
AsArgument
().
name
;
}
}
}
}
}
}
...
...
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
12db9f3c
...
@@ -36,7 +36,7 @@ class SSAGraph : GraphBase {
...
@@ -36,7 +36,7 @@ class SSAGraph : GraphBase {
// @param program: the op program
// @param program: the op program
// @param valid_places: the valid places user set for the system.
// @param valid_places: the valid places user set for the system.
void
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
void
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
// create
inputs
// create
temporary nodes.
for
(
const
auto
&
name
:
program
.
tmp_vars
)
{
for
(
const
auto
&
name
:
program
.
tmp_vars
)
{
node_storage_
.
emplace_back
();
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_node
=
node_storage_
.
back
();
...
@@ -45,20 +45,33 @@ class SSAGraph : GraphBase {
...
@@ -45,20 +45,33 @@ class SSAGraph : GraphBase {
arguments_
[
name
]
=
&
new_node
;
arguments_
[
name
]
=
&
new_node
;
}
}
// create weight nodes.
for
(
const
auto
&
name
:
program
.
weights
)
{
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
arg
=
new_node
.
AsArgument
();
arg
.
name
=
name
;
arguments_
[
name
]
=
&
new_node
;
}
for
(
auto
&
op
:
program
.
ops
)
{
for
(
auto
&
op
:
program
.
ops
)
{
node_storage_
.
emplace_back
();
node_storage_
.
emplace_back
();
// TODO(Superjomn) remove one valid_places here.
// TODO(Superjomn) remove one valid_places here.
op
->
SetValidPlaces
(
valid_places
);
op
->
SetValidPlaces
(
valid_places
);
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_node
=
node_storage_
.
back
();
node_storage_
.
back
().
AsInstruct
(
auto
kernels
=
op
->
CreateKernels
(
valid_places
);
op
->
op_type_
,
op
->
CreateKernels
(
valid_places
),
op
,
op
->
op_info
());
for
(
auto
&
kernel
:
kernels
)
{
op
->
AttachKernel
(
kernel
.
get
());
}
node_storage_
.
back
().
AsInstruct
(
op
->
op_type_
,
std
::
move
(
kernels
),
op
,
op
->
op_info
());
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
// collect inputs and outputs
// collect inputs and outputs
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
input_names
())
{
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
input_names
())
{
auto
*
arg
=
arguments_
.
a
t
(
name
);
auto
*
arg
=
Argumen
t
(
name
);
new_node
.
inlinks
.
push_back
(
arg
);
new_node
.
inlinks
.
push_back
(
arg
);
arg
->
outlinks
.
push_back
(
&
new_node
);
arg
->
outlinks
.
push_back
(
&
new_node
);
}
}
...
@@ -79,6 +92,12 @@ class SSAGraph : GraphBase {
...
@@ -79,6 +92,12 @@ class SSAGraph : GraphBase {
MarkArgumentWeights
(
program
);
MarkArgumentWeights
(
program
);
}
}
mir
::
Node
*
Argument
(
const
std
::
string
&
name
)
{
auto
it
=
arguments_
.
find
(
name
);
CHECK
(
it
!=
arguments_
.
end
())
<<
"no argument called "
<<
name
;
return
it
->
second
;
}
std
::
vector
<
mir
::
Node
*>
InstructTopologicalOrder
();
std
::
vector
<
mir
::
Node
*>
InstructTopologicalOrder
();
// The inputs of the graph.
// The inputs of the graph.
...
...
paddle/fluid/lite/core/op_executor_test.cc
浏览文件 @
12db9f3c
...
@@ -20,12 +20,9 @@ namespace paddle {
...
@@ -20,12 +20,9 @@ namespace paddle {
namespace
lite
{
namespace
lite
{
TEST
(
executor
,
test
)
{
TEST
(
executor
,
test
)
{
std
::
vector
<
OpLite
::
Place
>
valid_places
{
std
::
vector
<
Place
>
valid_places
{
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
OpLite
::
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
Scope
scope
;
auto
scope
=
std
::
make_shared
<
lite
::
Scope
>
();
Executor
executor
(
&
scope
,
valid_places
);
framework
::
ProgramDesc
program
;
framework
::
ProgramDesc
program
;
program
.
MutableBlock
(
0
)
->
Var
(
"x"
);
program
.
MutableBlock
(
0
)
->
Var
(
"x"
);
...
@@ -42,19 +39,18 @@ TEST(executor, test) {
...
@@ -42,19 +39,18 @@ TEST(executor, test) {
op_desc
.
SetAttr
(
"in_num_col_dims"
,
static_cast
<
int
>
(
1
));
op_desc
.
SetAttr
(
"in_num_col_dims"
,
static_cast
<
int
>
(
1
));
program
.
Flush
();
program
.
Flush
();
auto
*
w
=
scope
.
Var
(
"w"
)
->
GetMutable
<
Tensor
>
();
auto
*
w
=
scope
->
Var
(
"w"
)
->
GetMutable
<
Tensor
>
();
w
->
Resize
({
20
,
20
});
w
->
Resize
({
20
,
20
});
auto
*
x
=
scope
.
Var
(
"x"
)
->
GetMutable
<
Tensor
>
();
auto
*
x
=
scope
->
Var
(
"x"
)
->
GetMutable
<
Tensor
>
();
x
->
Resize
({
1
,
10
,
20
});
x
->
Resize
({
1
,
10
,
20
});
auto
*
bias
=
scope
.
Var
(
"bias"
)
->
GetMutable
<
Tensor
>
();
auto
*
bias
=
scope
->
Var
(
"bias"
)
->
GetMutable
<
Tensor
>
();
bias
->
Resize
({
1
,
20
});
bias
->
Resize
({
1
,
20
});
bias
->
mutable_data
<
float
>
();
bias
->
mutable_data
<
float
>
();
w
->
mutable_data
<
float
>
();
w
->
mutable_data
<
float
>
();
x
->
mutable_data
<
float
>
();
x
->
mutable_data
<
float
>
();
executor
.
PrepareWorkspace
(
program
);
lite
::
Executor
executor
(
program
,
scope
,
valid_places
);
executor
.
Build
(
program
);
executor
.
Run
();
executor
.
Run
();
}
}
...
@@ -62,4 +58,4 @@ TEST(executor, test) {
...
@@ -62,4 +58,4 @@ TEST(executor, test) {
}
// namespace paddle
}
// namespace paddle
USE_LITE_OP
(
fc
);
USE_LITE_OP
(
fc
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
def
);
paddle/fluid/lite/core/op_lite.h
浏览文件 @
12db9f3c
...
@@ -101,6 +101,9 @@ class OpLite : public Registry {
...
@@ -101,6 +101,9 @@ class OpLite : public Registry {
virtual
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
virtual
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
lite
::
Scope
*
scope
)
=
0
;
// Assign op param to kernel.
virtual
void
AttachKernel
(
KernelBase
*
kernel
)
=
0
;
// Specify the kernel to run by default. This will specify the value of
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
// `kernel_place_`.
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
{
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
{
...
...
paddle/fluid/lite/core/optimizer.cc
浏览文件 @
12db9f3c
...
@@ -11,3 +11,19 @@
...
@@ -11,3 +11,19 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
namespace
paddle
{
namespace
lite
{
void
Optimizer
::
SpecifyKernelPickTactic
(
core
::
KernelPickFactor
factor
)
{
auto
*
pass
=
mir
::
PassManager
::
Global
().
LookUp
<
mir
::
StaticKernelPickPass
>
(
"static_kernel_pick_pass"
);
CHECK
(
pass
);
*
pass
->
mutable_kernel_pick_factors
()
=
factor
;
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/optimizer.h
浏览文件 @
12db9f3c
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
...
@@ -30,11 +31,14 @@ namespace lite {
...
@@ -30,11 +31,14 @@ namespace lite {
class
Optimizer
{
class
Optimizer
{
public:
public:
void
Run
(
Program
&&
program
,
const
std
::
vector
<
Place
>&
valid_places
,
void
Run
(
Program
&&
program
,
const
std
::
vector
<
Place
>&
valid_places
,
core
::
KernelPickFactor
kernel_pick_factor
,
const
std
::
vector
<
std
::
string
>&
passes
=
{})
{
const
std
::
vector
<
std
::
string
>&
passes
=
{})
{
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
->
Build
(
program
,
valid_places
);
graph_
->
Build
(
program
,
valid_places
);
SpecifyKernelPickTactic
(
kernel_pick_factor
);
RunPasses
();
RunPasses
();
exec_scope_
=
program
.
exec_scope
;
}
}
// Generate a new program based on the mir graph.
// Generate a new program based on the mir graph.
...
@@ -42,7 +46,10 @@ class Optimizer {
...
@@ -42,7 +46,10 @@ class Optimizer {
std
::
unique_ptr
<
Program
>
res
;
std
::
unique_ptr
<
Program
>
res
;
auto
pass
=
mir
::
PassManager
::
Global
().
LookUp
<
mir
::
GenerateProgramPass
>
(
auto
pass
=
mir
::
PassManager
::
Global
().
LookUp
<
mir
::
GenerateProgramPass
>
(
"generate_program_pass"
);
"generate_program_pass"
);
return
pass
->
GenProgram
();
auto
program
=
pass
->
GenProgram
();
CHECK
(
exec_scope_
);
program
->
set_exec_scope
(
exec_scope_
);
return
program
;
}
}
// Generate C++ code which combines the inference program, model and weights.
// Generate C++ code which combines the inference program, model and weights.
...
@@ -54,6 +61,8 @@ class Optimizer {
...
@@ -54,6 +61,8 @@ class Optimizer {
}
}
protected:
protected:
void
SpecifyKernelPickTactic
(
core
::
KernelPickFactor
factor
);
// Run the default passes registered in the PassManager.
// Run the default passes registered in the PassManager.
void
RunPasses
()
{
mir
::
PassManager
::
Global
().
Run
(
graph_
);
}
void
RunPasses
()
{
mir
::
PassManager
::
Global
().
Run
(
graph_
);
}
...
@@ -62,6 +71,7 @@ class Optimizer {
...
@@ -62,6 +71,7 @@ class Optimizer {
private:
private:
std
::
unique_ptr
<
mir
::
SSAGraph
>
graph_
;
std
::
unique_ptr
<
mir
::
SSAGraph
>
graph_
;
lite
::
Scope
*
exec_scope_
{};
};
};
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/program.h
浏览文件 @
12db9f3c
...
@@ -35,23 +35,22 @@ struct Program {
...
@@ -35,23 +35,22 @@ struct Program {
std
::
list
<
std
::
shared_ptr
<
OpLite
>>
ops
;
std
::
list
<
std
::
shared_ptr
<
OpLite
>>
ops
;
// the scope to run the kernels, NOTE not the root scope.
// the scope to run the kernels, NOTE not the root scope.
std
::
shared_ptr
<
lite
::
Scope
>
scope
;
std
::
shared_ptr
<
lite
::
Scope
>
scope
;
std
::
vector
<
Place
>
valid_places
;
// Runtime scope.
// Runtime scope.
lite
::
Scope
*
exec_scope
{};
lite
::
Scope
*
exec_scope
{};
const
framework
::
ProgramDesc
desc
;
explicit
Program
(
const
std
::
shared_ptr
<
Scope
>&
root
)
{
scope
=
root
;
}
explicit
Program
(
const
std
::
shared_ptr
<
Scope
>&
root
)
{
scope
=
root
;
}
Program
(
const
framework
::
ProgramDesc
&
desc
,
Program
(
const
framework
::
ProgramDesc
&
desc
,
const
std
::
shared_ptr
<
Scope
>&
root
,
const
std
::
shared_ptr
<
Scope
>&
root
,
const
std
::
vector
<
Place
>&
valid_places
)
{
const
std
::
vector
<
Place
>&
valid_places
)
scope
=
root
;
:
scope
(
root
),
valid_places
(
valid_places
),
desc
(
desc
)
{
PrepareWorkspace
(
desc
);
PrepareWorkspace
(
desc
);
Build
(
desc
,
valid_places
);
Build
(
desc
,
valid_places
);
}
}
std
::
unique_ptr
<
Program
>
Clone
()
const
{
std
::
unique_ptr
<
Program
>
Clone
()
const
{
std
::
unique_ptr
<
Program
>
res
(
new
Program
(
scope
));
std
::
unique_ptr
<
Program
>
res
(
new
Program
(
desc
,
scope
,
valid_places
));
res
->
tmp_vars
=
tmp_vars
;
res
->
weights
=
weights
;
res
->
ops
=
ops
;
return
res
;
return
res
;
}
}
...
@@ -64,7 +63,7 @@ struct Program {
...
@@ -64,7 +63,7 @@ struct Program {
// Create operators.
// Create operators.
for
(
auto
*
op_desc
:
program
.
Block
(
0
).
AllOps
())
{
for
(
auto
*
op_desc
:
program
.
Block
(
0
).
AllOps
())
{
auto
op_type
=
op_desc
->
Type
();
auto
op_type
=
op_desc
->
Type
();
if
(
op_type
==
"feed"
||
op_type
==
"fetch"
)
continue
;
//
if (op_type == "feed" || op_type == "fetch") continue;
LOG
(
INFO
)
<<
"create Op ["
<<
op_type
<<
"]"
;
LOG
(
INFO
)
<<
"create Op ["
<<
op_type
<<
"]"
;
ops
.
emplace_back
(
LiteOpRegistry
::
Global
().
Create
(
op_type
));
ops
.
emplace_back
(
LiteOpRegistry
::
Global
().
Create
(
op_type
));
// pick initial kernel
// pick initial kernel
...
@@ -77,11 +76,22 @@ struct Program {
...
@@ -77,11 +76,22 @@ struct Program {
void
PrepareWorkspace
(
const
framework
::
ProgramDesc
&
program
)
{
void
PrepareWorkspace
(
const
framework
::
ProgramDesc
&
program
)
{
CHECK
(
!
exec_scope
)
<<
"Duplicate PrepareWorkspace found"
;
CHECK
(
!
exec_scope
)
<<
"Duplicate PrepareWorkspace found"
;
exec_scope
=
&
scope
->
NewScope
();
exec_scope
=
&
scope
->
NewScope
();
// Create Feed and Fetch var.
scope
->
Var
(
"feed"
)
->
GetMutable
<
std
::
vector
<
Tensor
>>
();
scope
->
Var
(
"fetch"
)
->
GetMutable
<
std
::
vector
<
Tensor
>>
();
tmp_vars
.
push_back
(
"feed"
);
tmp_vars
.
push_back
(
"fetch"
);
for
(
auto
var_desc
:
program
.
Block
(
0
).
AllVars
())
{
for
(
auto
var_desc
:
program
.
Block
(
0
).
AllVars
())
{
if
(
!
var_desc
->
Persistable
())
{
if
(
!
var_desc
->
Persistable
())
{
LOG
(
INFO
)
<<
"get tmp var "
<<
var_desc
->
Name
();
tmp_vars
.
push_back
(
var_desc
->
Name
());
auto
*
var
=
exec_scope
->
Var
(
var_desc
->
Name
());
auto
*
var
=
exec_scope
->
Var
(
var_desc
->
Name
());
LOG
(
INFO
)
<<
"create tmp var "
<<
var_desc
->
Name
()
<<
" "
<<
var
;
LOG
(
INFO
)
<<
"create tmp var "
<<
var_desc
->
Name
()
<<
" "
<<
var
;
}
else
{
if
(
var_desc
->
Name
()
==
"feed"
||
var_desc
->
Name
()
==
"fetch"
)
continue
;
LOG
(
INFO
)
<<
"get weight var "
<<
var_desc
->
Name
();
weights
.
push_back
(
var_desc
->
Name
());
}
}
}
}
}
}
...
@@ -118,11 +128,15 @@ class RuntimeProgram {
...
@@ -118,11 +128,15 @@ class RuntimeProgram {
}
}
}
}
void
set_exec_scope
(
lite
::
Scope
*
x
)
{
exec_scope_
=
x
;
}
lite
::
Scope
*
exec_scope
()
{
return
exec_scope_
;
}
size_t
num_instructions
()
const
{
return
instructions_
.
size
();
}
size_t
num_instructions
()
const
{
return
instructions_
.
size
();
}
private:
private:
RuntimeProgram
(
const
RuntimeProgram
&
)
=
delete
;
RuntimeProgram
(
const
RuntimeProgram
&
)
=
delete
;
std
::
vector
<
Instruction
>
instructions_
;
std
::
vector
<
Instruction
>
instructions_
;
lite
::
Scope
*
exec_scope_
{};
};
};
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/type_system.cc
浏览文件 @
12db9f3c
...
@@ -48,6 +48,45 @@ const Type* Type::Get<UnsupportedTy>(TargetType target) {
...
@@ -48,6 +48,45 @@ const Type* Type::Get<UnsupportedTy>(TargetType target) {
DataLayoutType
::
kNCHW
>
();
DataLayoutType
::
kNCHW
>
();
}
}
template
<
TargetType
Target
>
TensorListAnyTy
*
GetTensorListAnyTy
()
{
static
TensorListAnyTy
x
(
Target
);
return
&
x
;
}
template
<
TargetType
Target
>
TensorAnyTy
*
GetTensorAnyTy
()
{
static
TensorAnyTy
x
(
Target
);
return
&
x
;
}
template
<
>
const
Type
*
Type
::
Get
<
TensorListAnyTy
>
(
TargetType
target
)
{
switch
(
target
)
{
case
TargetType
::
kHost
:
return
GetTensorListAnyTy
<
TARGET
(
kHost
)
>
();
case
TargetType
::
kCUDA
:
return
GetTensorListAnyTy
<
TARGET
(
kCUDA
)
>
();
case
TargetType
::
kX86
:
return
GetTensorListAnyTy
<
TARGET
(
kX86
)
>
();
default:
LOG
(
FATAL
)
<<
"unsupported type"
;
}
}
template
<
>
const
Type
*
Type
::
Get
<
TensorAnyTy
>
(
TargetType
target
)
{
switch
(
target
)
{
case
TargetType
::
kHost
:
return
GetTensorAnyTy
<
TARGET
(
kHost
)
>
();
case
TargetType
::
kCUDA
:
return
GetTensorAnyTy
<
TARGET
(
kCUDA
)
>
();
case
TargetType
::
kX86
:
return
GetTensorAnyTy
<
TARGET
(
kX86
)
>
();
default:
LOG
(
FATAL
)
<<
"unsupported type"
;
}
}
template
<
>
template
<
>
const
Type
*
Type
::
Get
<
TensorFp32NCHWTy
>
(
TargetType
target
)
{
const
Type
*
Type
::
Get
<
TensorFp32NCHWTy
>
(
TargetType
target
)
{
switch
(
target
)
{
switch
(
target
)
{
...
...
paddle/fluid/lite/core/type_system.h
浏览文件 @
12db9f3c
...
@@ -60,6 +60,8 @@ class DataTypeBase {
...
@@ -60,6 +60,8 @@ class DataTypeBase {
// Tensor_Any represents a Tensor with any place, data, layout. It is used
// Tensor_Any represents a Tensor with any place, data, layout. It is used
// in some IO kernels those doesn't care the data.
// in some IO kernels those doesn't care the data.
Tensor_Any
,
Tensor_Any
,
// Used by feed or fetch op.
TensorList_Any
,
NumTypes
,
// Must remains as last defined ID.
NumTypes
,
// Must remains as last defined ID.
};
};
...
@@ -146,6 +148,13 @@ class TensorAnyTy : public Type {
...
@@ -146,6 +148,13 @@ class TensorAnyTy : public Type {
:
Type
(
ID
::
Tensor_Any
,
"TensorAny"
,
true
,
target
,
PRECISION
(
kAny
),
:
Type
(
ID
::
Tensor_Any
,
"TensorAny"
,
true
,
target
,
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
))
{}
DATALAYOUT
(
kAny
))
{}
};
};
// A list of tensor, and no assumption on the data layout or data type.
class
TensorListAnyTy
:
public
Type
{
public:
TensorListAnyTy
(
TargetType
target
)
:
Type
(
ID
::
TensorList_Any
,
"TensorList_Any"
,
false
,
target
,
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
))
{}
};
class
TensorFp32NCHWTy
:
public
Type
{
class
TensorFp32NCHWTy
:
public
Type
{
public:
public:
TensorFp32NCHWTy
(
TargetType
target
)
TensorFp32NCHWTy
(
TargetType
target
)
...
...
paddle/fluid/lite/kernels/host/CMakeLists.txt
浏览文件 @
12db9f3c
...
@@ -3,13 +3,15 @@ cc_library(relu_compute_host SRCS relu_compute.cc DEPS ${lite_kernel_deps})
...
@@ -3,13 +3,15 @@ cc_library(relu_compute_host SRCS relu_compute.cc DEPS ${lite_kernel_deps})
cc_library
(
mul_compute_host SRCS mul_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
mul_compute_host SRCS mul_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
scale_compute_host SRCS scale_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
scale_compute_host SRCS scale_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
feed_compute_host SRCS feed_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
feed_compute_host SRCS feed_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
fetch_compute_host SRCS fetch_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
host_kernels DEPS
cc_library
(
host_kernels DEPS
feed_compute_host
fetch_compute_host
fc_compute_host
fc_compute_host
relu_compute_host
relu_compute_host
mul_compute_host
mul_compute_host
scale_compute_host
scale_compute_host
feed_compute_host
DEPS
${
lite_kernel_deps
}
DEPS
${
lite_kernel_deps
}
)
)
...
...
paddle/fluid/lite/kernels/host/feed_compute.cc
浏览文件 @
12db9f3c
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include <Eigen/Core>
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
#include "paddle/fluid/lite/core/type_system.h"
...
@@ -26,9 +25,9 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
...
@@ -26,9 +25,9 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
using
param_t
=
operators
::
FeedParam
;
using
param_t
=
operators
::
FeedParam
;
void
Run
()
override
{
void
Run
()
override
{
auto
&
the
param
=
Param
<
operators
::
FeedParam
>
();
auto
&
param
=
Param
<
operators
::
FeedParam
>
();
const
Tensor
&
feed_item
=
theparam
.
feed_list
->
at
(
the
param
.
col
);
const
Tensor
&
feed_item
=
param
.
feed_list
->
at
(
param
.
col
);
the
param
.
out
->
CopyDataFrom
(
feed_item
);
param
.
out
->
CopyDataFrom
(
feed_item
);
}
}
};
};
...
@@ -39,7 +38,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
...
@@ -39,7 +38,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL
(
feed
,
kHost
,
kFloat
,
REGISTER_LITE_KERNEL
(
feed
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
FeedCompute
,
def
)
paddle
::
lite
::
kernels
::
host
::
FeedCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
Tensor
Fp32NCHW
Ty
>
(
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
Tensor
Any
Ty
>
(
TARGET
(
kHost
))})
TARGET
(
kHost
))})
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
TARGET
(
kHost
))})
...
...
paddle/fluid/lite/kernels/host/fetch_compute.cc
0 → 100644
浏览文件 @
12db9f3c
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
host
{
class
FetchCompute
:
public
OpKernel
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
{
public:
using
param_t
=
operators
::
FeedParam
;
void
Run
()
override
{
auto
&
param
=
Param
<
operators
::
FetchParam
>
();
auto
*
fetch_list
=
param
.
fetch_list
;
if
(
fetch_list
->
size
()
<=
static_cast
<
size_t
>
(
param
.
col
))
{
fetch_list
->
resize
(
param
.
col
+
1
);
}
auto
&
dst
=
fetch_list
->
at
(
param
.
col
);
dst
.
CopyDataFrom
(
*
param
.
input
);
}
};
}
// namespace host
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_KERNEL
(
fetch
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
FetchCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorAnyTy
>
(
TARGET
(
kHost
))})
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorListAnyTy
>
(
TARGET
(
kHost
))})
.
Finalize
();
paddle/fluid/lite/kernels/host/scale_compute.cc
浏览文件 @
12db9f3c
...
@@ -52,4 +52,8 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
...
@@ -52,4 +52,8 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
REGISTER_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
ScaleCompute
,
def
)
paddle
::
lite
::
kernels
::
host
::
ScaleCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
.
Finalize
();
.
Finalize
();
paddle/fluid/lite/operators/CMakeLists.txt
浏览文件 @
12db9f3c
...
@@ -3,6 +3,7 @@ cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
...
@@ -3,6 +3,7 @@ cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
cc_library
(
mul_op_lite SRCS mul_op.cc DEPS op_lite
)
cc_library
(
mul_op_lite SRCS mul_op.cc DEPS op_lite
)
cc_library
(
scale_op_lite SRCS scale_op.cc DEPS op_lite
)
cc_library
(
scale_op_lite SRCS scale_op.cc DEPS op_lite
)
cc_library
(
feed_op_lite SRCS feed_op.cc DEPS op_lite
)
cc_library
(
feed_op_lite SRCS feed_op.cc DEPS op_lite
)
cc_library
(
fetch_op_lite SRCS fetch_op.cc DEPS op_lite
)
cc_library
(
io_copy_op_lite SRCS io_copy_op.cc DEPS op_lite
)
cc_library
(
io_copy_op_lite SRCS io_copy_op.cc DEPS op_lite
)
cc_library
(
op_params_lite SRCS op_params.cc DEPS tensor_lite
)
cc_library
(
op_params_lite SRCS op_params.cc DEPS tensor_lite
)
...
@@ -12,6 +13,7 @@ cc_library(ops_lite DEPS
...
@@ -12,6 +13,7 @@ cc_library(ops_lite DEPS
mul_op_lite
mul_op_lite
scale_op_lite
scale_op_lite
feed_op_lite
feed_op_lite
fetch_op_lite
io_copy_op_lite
io_copy_op_lite
)
)
...
...
paddle/fluid/lite/operators/fc_op.h
浏览文件 @
12db9f3c
...
@@ -67,6 +67,8 @@ class FcOpLite : public OpLite {
...
@@ -67,6 +67,8 @@ class FcOpLite : public OpLite {
return
true
;
return
true
;
}
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"fc"
;
}
std
::
string
DebugString
()
const
override
{
return
"fc"
;
}
private:
private:
...
...
paddle/fluid/lite/operators/feed_op.cc
浏览文件 @
12db9f3c
...
@@ -31,6 +31,9 @@ class FeedOp : public OpLite {
...
@@ -31,6 +31,9 @@ class FeedOp : public OpLite {
bool
InferShape
()
const
override
{
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
protected:
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
protected:
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
lite
::
Scope
*
scope
)
override
{
...
@@ -48,6 +51,7 @@ class FeedOp : public OpLite {
...
@@ -48,6 +51,7 @@ class FeedOp : public OpLite {
// NOTE need boost here
// NOTE need boost here
// TODO(Superjomn) drop the need of framework::op_desc
// TODO(Superjomn) drop the need of framework::op_desc
param_
.
col
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"col"
));
param_
.
col
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"col"
));
kernel_
->
SetParam
(
param_
);
return
true
;
return
true
;
}
}
...
...
paddle/fluid/lite/operators/fetch_op.cc
0 → 100644
浏览文件 @
12db9f3c
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
class
FetchOp
:
public
OpLite
{
public:
explicit
FetchOp
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
{
CHECK_OR_FALSE
(
param_
.
input
);
CHECK_OR_FALSE
(
param_
.
fetch_list
);
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
_x
=
opdesc
.
Input
(
"X"
).
front
();
auto
*
x
=
scope
->
FindVar
(
_x
);
CHECK
(
x
);
param_
.
input
=
&
x
->
Get
<
Tensor
>
();
auto
_out
=
opdesc
.
Output
(
"Out"
).
front
();
auto
*
out
=
scope
->
FindVar
(
_out
);
param_
.
fetch_list
=
out
->
GetMutable
<
std
::
vector
<
lite
::
Tensor
>>
();
param_
.
col
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"col"
));
return
true
;
}
std
::
string
DebugString
()
const
override
{
return
"fetch"
;
}
private:
mutable
FetchParam
param_
;
};
}
// namespace operators
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_OP
(
fetch
,
paddle
::
lite
::
operators
::
FetchOp
);
paddle/fluid/lite/operators/io_copy_op.h
浏览文件 @
12db9f3c
...
@@ -28,6 +28,8 @@ class IoCopyOp : public OpLite {
...
@@ -28,6 +28,8 @@ class IoCopyOp : public OpLite {
bool
Run
()
override
;
bool
Run
()
override
;
std
::
string
DebugString
()
const
override
;
std
::
string
DebugString
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
protected:
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
paddle/fluid/lite/operators/mul_op.h
浏览文件 @
12db9f3c
...
@@ -36,6 +36,7 @@ class MulOpLite : public OpLite {
...
@@ -36,6 +36,7 @@ class MulOpLite : public OpLite {
bool
InferShape
()
const
override
;
bool
InferShape
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
lite
::
Scope
*
scope
)
override
{
...
...
paddle/fluid/lite/operators/op_params.h
浏览文件 @
12db9f3c
...
@@ -25,8 +25,14 @@ namespace lite {
...
@@ -25,8 +25,14 @@ namespace lite {
namespace
operators
{
namespace
operators
{
struct
FeedParam
{
struct
FeedParam
{
const
std
::
vector
<
Tensor
>*
feed_list
;
const
std
::
vector
<
Tensor
>*
feed_list
{};
Tensor
*
out
;
Tensor
*
out
{};
int
col
;
};
struct
FetchParam
{
const
Tensor
*
input
{};
std
::
vector
<
Tensor
>*
fetch_list
{};
int
col
;
int
col
;
};
};
...
@@ -69,8 +75,8 @@ struct IoCopyParam {
...
@@ -69,8 +75,8 @@ struct IoCopyParam {
Tensor
*
y
{};
Tensor
*
y
{};
};
};
using
param_t
=
using
param_t
=
variant
<
FeedParam
,
FetchParam
,
FcParam
,
ReluParam
,
MulParam
,
variant
<
FeedParam
,
FcParam
,
ReluParam
,
MulParam
,
ScaleParam
,
IoCopyParam
>
;
ScaleParam
,
IoCopyParam
>
;
}
// namespace operators
}
// namespace operators
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/operators/relu_op.h
浏览文件 @
12db9f3c
...
@@ -34,6 +34,7 @@ class ReluOp : public OpLite {
...
@@ -34,6 +34,7 @@ class ReluOp : public OpLite {
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"tanh"
;
}
std
::
string
DebugString
()
const
override
{
return
"tanh"
;
}
private:
private:
...
...
paddle/fluid/lite/operators/scale_op.cc
浏览文件 @
12db9f3c
...
@@ -43,6 +43,8 @@ class ScaleOp : public OpLite {
...
@@ -43,6 +43,8 @@ class ScaleOp : public OpLite {
return
true
;
return
true
;
}
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
lite
::
Scope
*
scope
)
override
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录