Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
70540d1b
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
70540d1b
编写于
4月 29, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add a new lightweight OpDesc compatible with the original framework::OpDesc
to support mobile
上级
6e19097b
变更
32
隐藏空白更改
内联
并排
Showing
32 changed file
with
929 addition
and
69 deletion
+929
-69
CMakeLists.txt
CMakeLists.txt
+2
-1
cmake/configure.cmake
cmake/configure.cmake
+4
-0
paddle/fluid/lite/api/CMakeLists.txt
paddle/fluid/lite/api/CMakeLists.txt
+3
-3
paddle/fluid/lite/api/cxx_api.h
paddle/fluid/lite/api/cxx_api.h
+1
-2
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+2
-2
paddle/fluid/lite/core/mir/io_complement_pass.cc
paddle/fluid/lite/core/mir/io_complement_pass.cc
+2
-16
paddle/fluid/lite/core/op_executor.h
paddle/fluid/lite/core/op_executor.h
+2
-0
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+1
-1
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+3
-5
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+15
-13
paddle/fluid/lite/model_parser/CMakeLists.txt
paddle/fluid/lite/model_parser/CMakeLists.txt
+8
-0
paddle/fluid/lite/model_parser/compatible_pb.cc
paddle/fluid/lite/model_parser/compatible_pb.cc
+15
-0
paddle/fluid/lite/model_parser/compatible_pb.h
paddle/fluid/lite/model_parser/compatible_pb.h
+45
-0
paddle/fluid/lite/model_parser/pb/CMakeLists.txt
paddle/fluid/lite/model_parser/pb/CMakeLists.txt
+2
-0
paddle/fluid/lite/model_parser/pb/block_desc.cc
paddle/fluid/lite/model_parser/pb/block_desc.cc
+13
-0
paddle/fluid/lite/model_parser/pb/block_desc.h
paddle/fluid/lite/model_parser/pb/block_desc.h
+123
-0
paddle/fluid/lite/model_parser/pb/op_desc.cc
paddle/fluid/lite/model_parser/pb/op_desc.cc
+15
-0
paddle/fluid/lite/model_parser/pb/op_desc.h
paddle/fluid/lite/model_parser/pb/op_desc.h
+234
-0
paddle/fluid/lite/model_parser/pb/program_desc.cc
paddle/fluid/lite/model_parser/pb/program_desc.cc
+13
-0
paddle/fluid/lite/model_parser/pb/program_desc.h
paddle/fluid/lite/model_parser/pb/program_desc.h
+13
-0
paddle/fluid/lite/model_parser/pb/var_desc.cc
paddle/fluid/lite/model_parser/pb/var_desc.cc
+271
-0
paddle/fluid/lite/model_parser/pb/var_desc.h
paddle/fluid/lite/model_parser/pb/var_desc.h
+123
-0
paddle/fluid/lite/operators/fc_op.h
paddle/fluid/lite/operators/fc_op.h
+2
-4
paddle/fluid/lite/operators/feed_op.cc
paddle/fluid/lite/operators/feed_op.cc
+2
-3
paddle/fluid/lite/operators/fetch_op.cc
paddle/fluid/lite/operators/fetch_op.cc
+2
-3
paddle/fluid/lite/operators/io_copy_op.cc
paddle/fluid/lite/operators/io_copy_op.cc
+1
-2
paddle/fluid/lite/operators/io_copy_op.h
paddle/fluid/lite/operators/io_copy_op.h
+1
-1
paddle/fluid/lite/operators/mul_op.h
paddle/fluid/lite/operators/mul_op.h
+3
-4
paddle/fluid/lite/operators/relu_op.cc
paddle/fluid/lite/operators/relu_op.cc
+1
-1
paddle/fluid/lite/operators/relu_op.h
paddle/fluid/lite/operators/relu_op.h
+1
-1
paddle/fluid/lite/operators/scale_op.cc
paddle/fluid/lite/operators/scale_op.cc
+4
-6
paddle/fluid/lite/utils/varient.h
paddle/fluid/lite/utils/varient.h
+2
-1
未找到文件。
CMakeLists.txt
浏览文件 @
70540d1b
...
@@ -186,7 +186,8 @@ endif()
...
@@ -186,7 +186,8 @@ endif()
# for lite
# for lite
option
(
LITE_WITH_CUDA
"Enable CUDA in lite mode"
ON
)
option
(
LITE_WITH_CUDA
"Enable CUDA in lite mode"
ON
)
option
(
LITE_WITH_X86
"Enable X86 in lite mode"
ON
)
option
(
LITE_WITH_X86
"Enable X86 in lite mode"
ON
)
option
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"Enable light-weight framework"
ON
)
include
(
external/threadpool
)
include
(
external/threadpool
)
include
(
flags
)
# set paddle compile flags
include
(
flags
)
# set paddle compile flags
...
...
cmake/configure.cmake
浏览文件 @
70540d1b
...
@@ -171,3 +171,7 @@ endif()
...
@@ -171,3 +171,7 @@ endif()
if
(
LITE_WITH_X86
)
if
(
LITE_WITH_X86
)
add_definitions
(
"-DLITE_WITH_X86"
)
add_definitions
(
"-DLITE_WITH_X86"
)
endif
()
endif
()
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
add_definitions
(
"-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK"
)
endif
()
paddle/fluid/lite/api/CMakeLists.txt
浏览文件 @
70540d1b
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host
)
cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite
)
if
(
LITE_WITH_CUDA
)
if
(
LITE_WITH_CUDA
)
cc_library
(
cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda
)
cc_library
(
cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda
)
nv_test
(
test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite
)
nv_test
(
test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite
)
else
()
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host
)
cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host host_kernels
)
endif
()
endif
()
paddle/fluid/lite/api/cxx_api.h
浏览文件 @
70540d1b
...
@@ -33,9 +33,8 @@ class Predictor {
...
@@ -33,9 +33,8 @@ class Predictor {
const
std
::
vector
<
Place
>&
valid_places
)
{
const
std
::
vector
<
Place
>&
valid_places
)
{
framework
::
proto
::
ProgramDesc
prog
;
framework
::
proto
::
ProgramDesc
prog
;
LoadModel
(
model_path
,
scope_
.
get
(),
&
prog
);
LoadModel
(
model_path
,
scope_
.
get
(),
&
prog
);
framework
::
ProgramDesc
prog_desc
(
prog
);
Program
program
(
prog
_desc
,
scope_
,
valid_places
);
Program
program
(
prog
,
scope_
,
valid_places
);
Optimizer
optimizer
;
Optimizer
optimizer
;
optimizer
.
KernelPickPreferPlace
(
prefer_place
);
optimizer
.
KernelPickPreferPlace
(
prefer_place
);
...
...
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
70540d1b
...
@@ -5,10 +5,10 @@ cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite)
...
@@ -5,10 +5,10 @@ cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite)
cc_library
(
variable_lite SRCS variable.cc
)
cc_library
(
variable_lite SRCS variable.cc
)
cc_library
(
op_registry_lite SRCS op_registry.cc
)
cc_library
(
op_registry_lite SRCS op_registry.cc
)
cc_library
(
scope_lite SRCS scope.cc
)
cc_library
(
scope_lite SRCS scope.cc
)
cc_library
(
op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite
)
cc_library
(
op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite
compatible_pb_lite
)
cc_library
(
op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_lite op_registry_lite
cc_library
(
op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_lite op_registry_lite
#TODO(Superjomn) remove these dependencies from original framework
#TODO(Superjomn) remove these dependencies from original framework
proto_desc
)
)
cc_library
(
kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite
)
cc_library
(
kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite
)
cc_library
(
types_lite SRCS types.cc
)
cc_library
(
types_lite SRCS types.cc
)
cc_library
(
type_system SRCS type_system.cc DEPS tensor_lite
)
cc_library
(
type_system SRCS type_system.cc DEPS tensor_lite
)
...
...
paddle/fluid/lite/core/mir/io_complement_pass.cc
浏览文件 @
70540d1b
...
@@ -65,19 +65,6 @@ void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
...
@@ -65,19 +65,6 @@ void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
}
}
}
}
void
UpdateOpdescInputName
(
framework
::
OpDesc
*
desc
,
const
std
::
string
&
old_arg_name
,
const
std
::
string
&
new_arg_name
)
{
for
(
auto
&
item
:
*
desc
->
Proto
()
->
mutable_inputs
())
{
for
(
int
i
=
0
;
i
<
item
.
mutable_arguments
()
->
size
();
i
++
)
{
auto
*
arg
=
item
.
mutable_arguments
(
i
);
if
(
*
arg
==
old_arg_name
)
{
*
arg
=
new_arg_name
;
}
}
}
}
void
IoComplementPass
::
AddIoCopyInst
(
const
Type
&
from
,
const
Type
&
to
,
void
IoComplementPass
::
AddIoCopyInst
(
const
Type
&
from
,
const
Type
&
to
,
const
std
::
string
&
var
,
SSAGraph
*
graph
,
const
std
::
string
&
var
,
SSAGraph
*
graph
,
Node
*
inst_node
,
Node
*
inst_node
,
...
@@ -99,11 +86,10 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
...
@@ -99,11 +86,10 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
inst_node
->
AsInstruct
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
inst_node
->
AsInstruct
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
// Create IoCopy Instruction.
// Create IoCopy Instruction.
framework
::
OpDesc
op_desc
;
lite
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"io_copy"
);
op_desc
.
SetType
(
"io_copy"
);
op_desc
.
SetInput
(
"Input"
,
{
var
});
op_desc
.
SetInput
(
"Input"
,
{
var
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
op_desc
.
Flush
();
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
AsInstruct
().
op
->
scope
());
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
AsInstruct
().
op
->
scope
());
auto
kernels
=
io_copy_op
->
CreateKernels
(
valid_places
);
auto
kernels
=
io_copy_op
->
CreateKernels
(
valid_places
);
...
@@ -126,7 +112,7 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
...
@@ -126,7 +112,7 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
auto
desc_dummy
=
inst_node
->
AsInstruct
().
op
->
op_info
()
->
desc
();
auto
desc_dummy
=
inst_node
->
AsInstruct
().
op
->
op_info
()
->
desc
();
UpdateInputTo
(
&
desc_dummy
,
var
,
io_copy_output_name
);
UpdateInputTo
(
&
desc_dummy
,
var
,
io_copy_output_name
);
framework
::
OpDesc
desc_fake
(
desc_dummy
,
nullptr
);
lite
::
OpDesc
desc_fake
(
desc_dummy
);
inst_node
->
AsInstruct
().
op
->
Attach
(
desc_fake
,
inst_node
->
AsInstruct
().
op
->
Attach
(
desc_fake
,
inst_node
->
AsInstruct
().
op
->
scope
());
inst_node
->
AsInstruct
().
op
->
scope
());
...
...
paddle/fluid/lite/core/op_executor.h
浏览文件 @
70540d1b
...
@@ -23,6 +23,7 @@
...
@@ -23,6 +23,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
/*
// The Executor is used to run the operators.
// The Executor is used to run the operators.
class Executor {
class Executor {
public:
public:
...
@@ -63,6 +64,7 @@ class RuntimeExecutor {
...
@@ -63,6 +64,7 @@ class RuntimeExecutor {
private:
private:
RuntimeProgram* program_{};
RuntimeProgram* program_{};
};
};
*/
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
70540d1b
...
@@ -76,7 +76,7 @@ bool OpLite::Run() {
...
@@ -76,7 +76,7 @@ bool OpLite::Run() {
return
true
;
return
true
;
}
}
bool
OpLite
::
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
OpLite
::
Attach
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
CHECK
(
scope
);
CHECK
(
scope
);
scope_
=
scope
;
scope_
=
scope
;
op_info_
.
reset
(
new
OpInfo
);
// Force clean the out-of-date infomation.
op_info_
.
reset
(
new
OpInfo
);
// Force clean the out-of-date infomation.
...
...
paddle/fluid/lite/core/op_lite.h
浏览文件 @
70540d1b
...
@@ -19,12 +19,11 @@
...
@@ -19,12 +19,11 @@
#include <map>
#include <map>
#include <memory>
#include <memory>
#include <string>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
...
@@ -82,7 +81,7 @@ class OpLite : public Registry {
...
@@ -82,7 +81,7 @@ class OpLite : public Registry {
virtual
bool
Run
();
virtual
bool
Run
();
// Link the external execution environ to internal context.
// Link the external execution environ to internal context.
bool
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
);
bool
Attach
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
);
const
OpInfo
*
op_info
()
const
{
return
op_info_
.
get
();
}
const
OpInfo
*
op_info
()
const
{
return
op_info_
.
get
();
}
OpInfo
*
mutable_op_info
()
{
return
op_info_
.
get
();
}
OpInfo
*
mutable_op_info
()
{
return
op_info_
.
get
();
}
...
@@ -109,8 +108,7 @@ class OpLite : public Registry {
...
@@ -109,8 +108,7 @@ class OpLite : public Registry {
protected:
protected:
// Attach it with the runtime environment.
// Attach it with the runtime environment.
virtual
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
virtual
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
lite
::
Scope
*
scope
)
=
0
;
// Specify the kernel to run by default. This will specify the value of
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
// `kernel_place_`.
...
...
paddle/fluid/lite/core/program.h
浏览文件 @
70540d1b
...
@@ -38,10 +38,10 @@ struct Program {
...
@@ -38,10 +38,10 @@ struct Program {
std
::
vector
<
Place
>
valid_places
;
std
::
vector
<
Place
>
valid_places
;
// Runtime scope.
// Runtime scope.
lite
::
Scope
*
exec_scope
{};
lite
::
Scope
*
exec_scope
{};
const
framework
::
ProgramDesc
desc
;
const
framework
::
proto
::
ProgramDesc
desc
;
explicit
Program
(
const
std
::
shared_ptr
<
Scope
>&
root
)
{
scope
=
root
;
}
explicit
Program
(
const
std
::
shared_ptr
<
Scope
>&
root
)
{
scope
=
root
;
}
Program
(
const
framework
::
ProgramDesc
&
desc
,
Program
(
const
framework
::
proto
::
ProgramDesc
&
desc
,
const
std
::
shared_ptr
<
Scope
>&
root
,
const
std
::
shared_ptr
<
Scope
>&
root
,
const
std
::
vector
<
Place
>&
valid_places
)
const
std
::
vector
<
Place
>&
valid_places
)
:
scope
(
root
),
valid_places
(
valid_places
),
desc
(
desc
)
{
:
scope
(
root
),
valid_places
(
valid_places
),
desc
(
desc
)
{
...
@@ -56,24 +56,25 @@ struct Program {
...
@@ -56,24 +56,25 @@ struct Program {
private:
private:
// Build from a program and scope.
// Build from a program and scope.
void
Build
(
const
framework
::
ProgramDesc
&
program
,
void
Build
(
const
framework
::
proto
::
ProgramDesc
&
program
,
const
std
::
vector
<
Place
>&
valid_places
)
{
const
std
::
vector
<
Place
>&
valid_places
)
{
CHECK
(
ops
.
empty
())
<<
"Executor duplicate Build found"
;
CHECK
(
ops
.
empty
())
<<
"Executor duplicate Build found"
;
// Create operators.
// Create operators.
for
(
auto
*
op_desc
:
program
.
Block
(
0
).
AllOps
())
{
for
(
const
auto
&
proto_op_desc
:
program
.
blocks
(
0
).
ops
())
{
auto
op_type
=
op_desc
->
Type
();
lite
::
OpDesc
op_desc
(
proto_op_desc
);
auto
op_type
=
op_desc
.
Type
();
// if (op_type == "feed" || op_type == "fetch") continue;
// if (op_type == "feed" || op_type == "fetch") continue;
VLOG
(
4
)
<<
"create Op ["
<<
op_type
<<
"]"
;
VLOG
(
4
)
<<
"create Op ["
<<
op_type
<<
"]"
;
ops
.
emplace_back
(
LiteOpRegistry
::
Global
().
Create
(
op_type
));
ops
.
emplace_back
(
LiteOpRegistry
::
Global
().
Create
(
op_type
));
// pick initial kernel
// pick initial kernel
ops
.
back
()
->
PickKernel
(
valid_places
);
ops
.
back
()
->
PickKernel
(
valid_places
);
ops
.
back
()
->
Attach
(
*
op_desc
,
exec_scope
);
ops
.
back
()
->
Attach
(
op_desc
,
exec_scope
);
}
}
}
}
// Create temporary variables.
// Create temporary variables.
void
PrepareWorkspace
(
const
framework
::
ProgramDesc
&
program
)
{
void
PrepareWorkspace
(
const
framework
::
proto
::
ProgramDesc
&
program
)
{
CHECK
(
!
exec_scope
)
<<
"Duplicate PrepareWorkspace found"
;
CHECK
(
!
exec_scope
)
<<
"Duplicate PrepareWorkspace found"
;
exec_scope
=
&
scope
->
NewScope
();
exec_scope
=
&
scope
->
NewScope
();
// Create Feed and Fetch var.
// Create Feed and Fetch var.
...
@@ -82,13 +83,14 @@ struct Program {
...
@@ -82,13 +83,14 @@ struct Program {
tmp_vars
.
push_back
(
"feed"
);
tmp_vars
.
push_back
(
"feed"
);
tmp_vars
.
push_back
(
"fetch"
);
tmp_vars
.
push_back
(
"fetch"
);
for
(
auto
var_desc
:
program
.
Block
(
0
).
AllVars
())
{
for
(
auto
proto_var_desc
:
program
.
blocks
(
0
).
vars
())
{
if
(
!
var_desc
->
Persistable
())
{
lite
::
VarDesc
var_desc
(
proto_var_desc
);
tmp_vars
.
push_back
(
var_desc
->
Name
());
if
(
!
var_desc
.
Persistable
())
{
exec_scope
->
Var
(
var_desc
->
Name
());
tmp_vars
.
push_back
(
var_desc
.
Name
());
exec_scope
->
Var
(
var_desc
.
Name
());
}
else
{
}
else
{
if
(
var_desc
->
Name
()
==
"feed"
||
var_desc
->
Name
()
==
"fetch"
)
continue
;
if
(
var_desc
.
Name
()
==
"feed"
||
var_desc
.
Name
()
==
"fetch"
)
continue
;
weights
.
push_back
(
var_desc
->
Name
());
weights
.
push_back
(
var_desc
.
Name
());
}
}
}
}
}
}
...
...
paddle/fluid/lite/model_parser/CMakeLists.txt
浏览文件 @
70540d1b
cc_library
(
model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite tensor_lite scope_lite
)
cc_library
(
model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite tensor_lite scope_lite
)
cc_library
(
runtime_lite SRCS runtime.cc
)
cc_library
(
runtime_lite SRCS runtime.cc
)
cc_test
(
test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite
)
cc_test
(
test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite
)
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite var_desc_lite
)
else
()
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto
)
endif
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
add_subdirectory
(
pb
)
paddle/fluid/lite/model_parser/compatible_pb.cc
0 → 100644
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
paddle/fluid/lite/model_parser/compatible_pb.h
0 → 100644
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
/*
* This file implements the interface to manipute the protobuf message. We use
* macros to make a compatible interface with the framework::XXDesc and
* lite::pb::XXDesc.
*/
#include "paddle/fluid/framework/framework.pb.h"
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
#include "paddle/fluid/lite/model_parser/pb/var_desc.h"
#else
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
namespace
paddle
{
namespace
lite
{
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
using
OpDesc
=
lite
::
pb
::
OpDesc
;
using
VarDesc
=
lite
::
pb
::
VarDesc
;
#else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
using
Attribute
=
framework
::
Attribute
;
using
OpDesc
=
framework
::
OpDesc
;
using
VarDesc
=
framework
::
VarDesc
;
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/pb/CMakeLists.txt
0 → 100644
浏览文件 @
70540d1b
cc_library
(
var_desc_lite SRCS var_desc.cc DEPS framework_proto
)
cc_library
(
op_desc_lite SRCS op_desc.cc DEPS framework_proto
)
paddle/fluid/lite/model_parser/pb/block_desc.cc
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
paddle/fluid/lite/model_parser/pb/block_desc.h
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <deque>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/macros.h"
namespace
paddle
{
namespace
lite
{
class
ProgramDesc
;
// Each Protobuf Message, we provide a XXXBind class. In that class, we optimize
// read/write speed. Only when we want the protobuf message, the local changes
// will be synchronized (by `Sync` method).
class
BlockDesc
{
public:
BlockDesc
(
ProgramDesc
*
prog
,
proto
::
BlockDesc
*
desc
);
BlockDesc
(
const
BlockDesc
&
other
,
proto
::
BlockDesc
*
desc
,
ProgramDesc
*
prog
);
int32_t
ID
()
const
{
return
desc_
->
idx
();
}
int32_t
Parent
()
const
{
return
desc_
->
parent_idx
();
}
int32_t
ForwardBlockID
()
const
{
return
desc_
->
forward_block_idx
();
}
VarDesc
*
Var
(
const
std
::
string
&
name_bytes
);
VarDesc
*
FindVar
(
const
std
::
string
&
name_bytes
)
const
;
bool
HasVar
(
const
std
::
string
&
var_name
)
const
;
VarDesc
*
RenameVar
(
const
std
::
string
&
old_name
,
const
std
::
string
&
new_name
);
VarDesc
*
FindVarRecursive
(
const
std
::
string
&
name_bytes
)
const
;
VarDesc
&
FindRecursiveOrCreateVar
(
const
std
::
string
&
name_bytes
);
bool
HasVarRecursive
(
const
std
::
string
&
var_name
)
const
;
std
::
set
<
std
::
string
>
LocalVarNames
()
const
{
std
::
set
<
std
::
string
>
var_names
;
for
(
auto
&
var
:
vars_
)
{
var_names
.
insert
(
var
.
first
);
}
return
var_names
;
}
std
::
vector
<
VarDesc
*>
AllVars
()
const
;
BlockDesc
*
ParentBlock
()
const
;
BlockDesc
*
ForwardBlock
()
const
;
void
SetForwardBlockID
(
int32_t
forward_block_id
);
OpDesc
*
AppendOp
();
void
AppendAllocatedOp
(
std
::
unique_ptr
<
OpDesc
>
&&
op_desc
);
OpDesc
*
PrependOp
();
void
PrependAllocatedOp
(
std
::
unique_ptr
<
OpDesc
>
&&
op_desc
);
OpDesc
*
InsertOp
(
size_t
index
);
/*
* Only remove op itself,
* do nothing to its input and output variables
*/
void
RemoveOp
(
size_t
s
,
size_t
e
);
void
RemoveOpInternal
(
const
OpDesc
*
op_desc
);
void
RemoveVar
(
const
std
::
string
&
name
)
{
vars_
.
erase
(
name
);
}
std
::
vector
<
OpDesc
*>
AllOps
()
const
;
size_t
OpSize
()
const
{
return
ops_
.
size
();
}
OpDesc
*
Op
(
int
idx
)
const
{
return
ops_
.
at
(
idx
).
get
();
}
void
Flush
();
proto
::
BlockDesc
*
Proto
();
ProgramDesc
*
Program
()
const
{
return
this
->
prog_
;
}
private:
ProgramDesc
*
prog_
;
// not_own
proto
::
BlockDesc
*
desc_
;
// not_own
bool
need_update_
;
std
::
deque
<
std
::
unique_ptr
<
OpDesc
>>
ops_
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
VarDesc
>>
vars_
;
DISABLE_COPY_AND_ASSIGN
(
BlockDesc
);
};
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/pb/op_desc.cc
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
paddle/fluid/lite/model_parser/pb/op_desc.h
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
/*
* This file implements a light-weight OpDesc like the framework::OpDesc. We
* delete the unnecessary methods, and remove the underlying dependencies, such
* as framework::Operator and boost::varient to make it runnable in mobile.
*/
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/lite/utils/all.h"
namespace
paddle
{
namespace
lite
{
namespace
pb
{
using
Attribute
=
variant
<
int
,
float
,
bool
,
std
::
vector
<
std
::
string
>>
;
using
VariableNameMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
/*
* The lite::OpDesc, an light-weight implementation of wrapper of proto::OpDesc.
* Unlike the original one in framework::OpDesc, we remove the local members
* except the desc_, to avoid the inconsistent state, which is normal in the
* original interface and results in bugs.
*/
class
OpDesc
{
public:
OpDesc
()
{}
OpDesc
(
const
framework
::
proto
::
OpDesc
&
desc
)
:
desc_
(
desc
)
{}
void
CopyFrom
(
const
OpDesc
&
op_desc
)
{
desc_
=
op_desc
.
ReadonlyProto
();
}
framework
::
proto
::
OpDesc
*
Proto
()
{
return
&
desc_
;
}
const
framework
::
proto
::
OpDesc
&
ReadonlyProto
()
const
{
return
desc_
;
}
std
::
string
Type
()
const
{
return
desc_
.
type
();
}
void
SetType
(
const
std
::
string
&
type
)
{
desc_
.
set_type
(
type
);
}
// Get the arguments of parameter called `param`
std
::
vector
<
std
::
string
>
Input
(
const
std
::
string
&
param
)
const
{
return
GetArguments
(
desc_
.
inputs
(),
param
);
}
std
::
vector
<
std
::
string
>
InputArgumentNames
()
const
{
return
GetArgumentNames
(
desc_
.
inputs
());
}
void
SetInput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>
&
args
)
{
SetArgument
(
desc_
.
mutable_inputs
(),
param
,
args
);
}
std
::
vector
<
std
::
string
>
Output
(
const
std
::
string
&
param
)
const
{
return
GetArguments
(
desc_
.
outputs
(),
param
);
}
std
::
vector
<
std
::
string
>
OutputArgumentNames
()
const
{
return
GetArgumentNames
(
desc_
.
outputs
());
}
void
SetOutput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>
&
args
)
{
SetArgument
(
desc_
.
mutable_outputs
(),
param
,
args
);
}
bool
HasAttr
(
const
std
::
string
&
name
)
const
{
const
auto
&
xs
=
desc_
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
return
it
!=
xs
.
end
();
}
framework
::
proto
::
AttrType
GetAttrType
(
const
std
::
string
&
name
)
const
{
const
auto
&
xs
=
desc_
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
CHECK
(
it
!=
xs
.
end
());
return
it
->
type
();
}
std
::
vector
<
std
::
string
>
AttrNames
()
const
{
std
::
vector
<
std
::
string
>
res
;
const
auto
&
xs
=
desc_
.
attrs
();
std
::
transform
(
xs
.
begin
(),
xs
.
end
(),
std
::
back_inserter
(
res
),
[](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
();
});
return
res
;
}
template
<
typename
T
>
void
SetAttr
(
const
std
::
string
&
name
,
const
T
&
v
)
{
auto
&
xs
=
*
desc_
.
mutable_attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
if
(
it
==
xs
.
end
())
{
auto
*
attr
=
xs
.
Add
();
attr
->
set_name
(
name
);
it
=
std
::
find
(
xs
.
begin
(),
xs
.
end
(),
name
);
}
switch
(
typeid
(
T
).
hash_code
())
{
case
typeid
(
int
).
hash_code
():
it
->
set_type
(
framework
::
proto
::
INT
);
it
->
set_i
(
v
);
break
;
case
typeid
(
float
).
hash_code
():
it
->
set_type
(
framework
::
proto
::
FLOAT
);
it
->
set_f
(
v
);
break
;
case
typeid
(
std
::
string
).
hash_code
():
it
->
set_type
(
framework
::
proto
::
STRING
);
it
->
set_s
(
v
.
c_str
());
break
;
case
typeid
(
std
::
string
).
hash_code
():
it
->
set_type
(
framework
::
proto
::
BOOLEAN
);
it
->
set_b
(
v
);
break
;
default:
LOG
(
FATAL
)
<<
"unsupport attr type"
;
}
}
Attribute
GetAttr
(
const
std
::
string
&
name
)
const
{
auto
&
xs
=
desc_
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
Attribute
res
;
CHECK
(
it
!=
xs
.
end
());
switch
(
it
->
type
())
{
case
framework
::
proto
::
INT
:
res
.
set
<
int
>
(
it
->
i
());
break
;
case
framework
::
proto
::
FLOAT
:
res
.
set
<
float
>
(
it
->
f
());
break
;
case
framework
::
proto
::
STRING
:
res
.
set
<
std
::
string
>
(
it
->
s
());
break
;
case
framework
::
proto
::
BOOLEAN
:
res
.
set
<
bool
>
(
it
->
b
());
break
;
default:
LOG
(
FATAL
)
<<
"unsupported attr type"
;
}
return
res
;
}
private:
std
::
vector
<
std
::
string
>
GetArguments
(
const
google
::
protobuf
::
RepeatedPtrField
<
framework
::
proto
::
OpDesc_Var
>
&
xs
,
const
std
::
string
&
param
)
const
{
std
::
vector
<
std
::
string
>
res
;
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Var
&
it
)
{
return
it
.
parameter
()
==
param
;
});
CHECK
(
it
!=
xs
.
end
());
const
auto
&
ys
=
it
->
arguments
();
std
::
transform
(
ys
.
begin
(),
ys
.
end
(),
std
::
back_inserter
(
res
),
[](
const
std
::
string
&
x
)
{
return
x
;
});
return
res
;
}
void
SetArgument
(
google
::
protobuf
::
RepeatedPtrField
<
framework
::
proto
::
OpDesc_Var
>
*
xs
,
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>
&
args
)
{
auto
it
=
std
::
find_if
(
xs
->
begin
(),
xs
->
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Var
&
it
)
{
return
it
.
parameter
()
==
param
;
});
if
(
it
==
xs
->
end
())
{
auto
*
new_arg
=
xs
->
Add
();
new_arg
->
set_parameter
(
param
);
for
(
const
auto
&
arg
:
args
)
{
*
new_arg
->
mutable_arguments
()
->
Add
()
=
arg
;
}
}
else
{
it
->
mutable_arguments
()
->
Clear
();
for
(
const
auto
&
arg
:
args
)
{
*
it
->
mutable_arguments
()
->
Add
()
=
arg
;
}
}
}
std
::
vector
<
std
::
string
>
GetArgumentNames
(
const
google
::
protobuf
::
RepeatedPtrField
<
framework
::
proto
::
OpDesc_Var
>
&
xs
)
const
{
std
::
vector
<
std
::
string
>
res
;
std
::
transform
(
xs
.
begin
(),
xs
.
end
(),
std
::
back_inserter
(
res
),
[](
const
framework
::
proto
::
OpDesc_Var
&
x
)
{
return
x
.
parameter
();
});
return
res
;
}
private:
framework
::
proto
::
OpDesc
desc_
;
};
}
// namespace pb
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/pb/program_desc.cc
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
paddle/fluid/lite/model_parser/pb/program_desc.h
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
paddle/fluid/lite/model_parser/pb/var_desc.cc
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/pb/var_desc.h"
namespace
paddle
{
namespace
lite
{
namespace
pb
{
using
namespace
framework
;
proto
::
VarType
::
Type
VarDesc
::
GetType
()
const
{
return
desc_
.
type
().
type
();
}
void
VarDesc
::
SetType
(
proto
::
VarType
::
Type
type
)
{
desc_
.
mutable_type
()
->
set_type
(
type
);
}
void
VarDesc
::
SetShape
(
const
std
::
vector
<
int64_t
>
&
dims
)
{
VectorToRepeated
(
dims
,
mutable_tensor_desc
()
->
mutable_dims
());
}
void
VarDesc
::
SetTensorDescNum
(
size_t
num
)
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
READER
:
{
auto
*
lod_tensors_ptr
=
desc_
.
mutable_type
()
->
mutable_reader
()
->
mutable_lod_tensor
();
lod_tensors_ptr
->
Clear
();
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
lod_tensors_ptr
->
Add
();
}
return
;
}
break
;
default:
LOG
(
FATAL
)
<<
"Setting 'sub_tensor_number' is not supported by the type "
"of var %s."
<<
this
->
Name
();
}
}
size_t
VarDesc
::
GetTensorDescNum
()
const
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
READER
:
return
desc_
.
type
().
reader
().
lod_tensor_size
();
break
;
default:
LOG
(
FATAL
)
<<
"Getting 'sub_tensor_number' is not supported by the type "
"of var %s."
<<
this
->
Name
();
}
}
void
VarDesc
::
SetShapes
(
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
multiple_dims
)
{
if
(
multiple_dims
.
size
()
!=
GetTensorDescNum
())
{
VLOG
(
3
)
<<
"WARNING: The number of given shapes("
<<
multiple_dims
.
size
()
<<
") doesn't match the existing tensor number("
<<
GetTensorDescNum
()
<<
"). The Reader is going to be reinitialized."
;
SetTensorDescNum
(
multiple_dims
.
size
());
}
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
tensors
=
mutable_tensor_descs
();
for
(
size_t
i
=
0
;
i
<
multiple_dims
.
size
();
++
i
)
{
VectorToRepeated
(
multiple_dims
[
i
],
tensors
[
i
]
->
mutable_dims
());
}
}
std
::
vector
<
int64_t
>
VarDesc
::
GetShape
()
const
{
return
RepeatedToVector
(
tensor_desc
().
dims
());
}
std
::
vector
<
std
::
vector
<
int64_t
>>
VarDesc
::
GetShapes
()
const
{
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
descs
=
tensor_descs
();
std
::
vector
<
std
::
vector
<
int64_t
>>
res
;
res
.
reserve
(
descs
.
size
());
for
(
const
auto
&
tensor_desc
:
descs
)
{
res
.
push_back
(
RepeatedToVector
(
tensor_desc
.
dims
()));
}
return
res
;
}
void
VarDesc
::
SetDataType
(
proto
::
VarType
::
Type
data_type
)
{
mutable_tensor_desc
()
->
set_data_type
(
data_type
);
}
void
VarDesc
::
SetDataTypes
(
const
std
::
vector
<
proto
::
VarType
::
Type
>
&
multiple_data_type
)
{
if
(
multiple_data_type
.
size
()
!=
GetTensorDescNum
())
{
VLOG
(
3
)
<<
"WARNING: The number of given data types("
<<
multiple_data_type
.
size
()
<<
") doesn't match the existing tensor number("
<<
GetTensorDescNum
()
<<
"). The Reader is going to be reinitialized."
;
SetTensorDescNum
(
multiple_data_type
.
size
());
}
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
tensor_descs
=
mutable_tensor_descs
();
for
(
size_t
i
=
0
;
i
<
multiple_data_type
.
size
();
++
i
)
{
tensor_descs
[
i
]
->
set_data_type
(
multiple_data_type
[
i
]);
}
}
proto
::
VarType
::
Type
VarDesc
::
GetDataType
()
const
{
return
tensor_desc
().
data_type
();
}
std
::
vector
<
proto
::
VarType
::
Type
>
VarDesc
::
GetDataTypes
()
const
{
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
descs
=
tensor_descs
();
std
::
vector
<
proto
::
VarType
::
Type
>
res
;
res
.
reserve
(
descs
.
size
());
for
(
const
auto
&
tensor_desc
:
descs
)
{
res
.
push_back
(
tensor_desc
.
data_type
());
}
return
res
;
}
void
VarDesc
::
SetLoDLevel
(
int32_t
lod_level
)
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
LOD_TENSOR
:
desc_
.
mutable_type
()
->
mutable_lod_tensor
()
->
set_lod_level
(
lod_level
);
break
;
case
proto
::
VarType
::
LOD_TENSOR_ARRAY
:
desc_
.
mutable_type
()
->
mutable_tensor_array
()
->
set_lod_level
(
lod_level
);
break
;
default:
LOG
(
FATAL
)
<<
"Setting 'lod_level' is not supported by the type of var %s."
<<
this
->
Name
();
}
}
void
VarDesc
::
SetLoDLevels
(
const
std
::
vector
<
int32_t
>
&
multiple_lod_level
)
{
if
(
multiple_lod_level
.
size
()
!=
GetTensorDescNum
())
{
VLOG
(
3
)
<<
"WARNING: The number of given lod_levels("
<<
multiple_lod_level
.
size
()
<<
") doesn't match the existing tensor number("
<<
GetTensorDescNum
()
<<
"). The Reader is going to be reinitialized."
;
SetTensorDescNum
(
multiple_lod_level
.
size
());
}
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
READER
:
{
size_t
i
=
0
;
for
(
auto
&
lod_tensor
:
*
desc_
.
mutable_type
()
->
mutable_reader
()
->
mutable_lod_tensor
())
{
lod_tensor
.
set_lod_level
(
multiple_lod_level
[
i
++
]);
}
}
break
;
default:
LOG
(
FATAL
)
<<
"Setting 'lod_levels' is not supported by the type of var %s."
<<
this
->
Name
();
}
}
int32_t
VarDesc
::
GetLoDLevel
()
const
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
LOD_TENSOR
:
return
desc_
.
type
().
lod_tensor
().
lod_level
();
case
proto
::
VarType
::
LOD_TENSOR_ARRAY
:
return
desc_
.
type
().
tensor_array
().
lod_level
();
default:
LOG
(
FATAL
)
<<
"Getting 'lod_level' is not supported by the type of var %s."
<<
this
->
Name
();
}
}
std
::
vector
<
int32_t
>
VarDesc
::
GetLoDLevels
()
const
{
std
::
vector
<
int32_t
>
res
;
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
READER
:
res
.
reserve
(
desc_
.
type
().
reader
().
lod_tensor_size
());
for
(
auto
&
lod_tensor
:
desc_
.
type
().
reader
().
lod_tensor
())
{
res
.
push_back
(
lod_tensor
.
lod_level
());
}
return
res
;
break
;
default:
LOG
(
FATAL
)
<<
"Getting 'lod_levels' is not supported by the type of var %s."
<<
this
->
Name
();
}
}
const
proto
::
VarType
::
TensorDesc
&
VarDesc
::
tensor_desc
()
const
{
CHECK
(
desc_
.
has_type
())
<<
"The var's type hasn't been set."
;
CHECK
(
desc_
.
type
().
has_type
())
<<
"The var type hasn't been set."
;
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
SELECTED_ROWS
:
return
desc_
.
type
().
selected_rows
();
case
proto
::
VarType
::
LOD_TENSOR
:
return
desc_
.
type
().
lod_tensor
().
tensor
();
case
proto
::
VarType
::
LOD_TENSOR_ARRAY
:
return
desc_
.
type
().
tensor_array
().
tensor
();
default:
LOG
(
FATAL
)
<<
"Getting 'tensor_desc' is not supported by the type of var %s."
<<
this
->
Name
();
}
}
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
VarDesc
::
tensor_descs
()
const
{
CHECK
(
desc_
.
has_type
())
<<
"The var type hasn't been set."
;
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
res
;
res
.
reserve
(
GetTensorDescNum
());
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
READER
:
for
(
const
auto
&
lod_tensor
:
desc_
.
type
().
reader
().
lod_tensor
())
{
res
.
push_back
(
lod_tensor
.
tensor
());
}
return
res
;
default:
LOG
(
FATAL
)
<<
"Getting 'tensor_descs' is not supported by the type of var "
"%s."
<<
this
->
Name
();
}
}
proto
::
VarType
::
TensorDesc
*
VarDesc
::
mutable_tensor_desc
()
{
CHECK
(
desc_
.
has_type
())
<<
"The var type hasn't been set."
;
CHECK
(
desc_
.
type
().
has_type
())
<<
"The var type hasn't been set."
;
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
SELECTED_ROWS
:
return
desc_
.
mutable_type
()
->
mutable_selected_rows
();
case
proto
::
VarType
::
LOD_TENSOR
:
return
desc_
.
mutable_type
()
->
mutable_lod_tensor
()
->
mutable_tensor
();
case
proto
::
VarType
::
LOD_TENSOR_ARRAY
:
return
desc_
.
mutable_type
()
->
mutable_tensor_array
()
->
mutable_tensor
();
default:
LOG
(
FATAL
)
<<
"Getting 'mutable_tensor_desc' is not supported by the "
"type of var "
"%s."
<<
this
->
Name
();
}
}
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
VarDesc
::
mutable_tensor_descs
()
{
CHECK
(
desc_
.
has_type
())
<<
"The var type hasn't been set."
;
CHECK
(
desc_
.
type
().
has_type
())
<<
"The var type hasn't been set."
;
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
res
;
res
.
reserve
(
GetTensorDescNum
());
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
READER
:
for
(
auto
&
lod_tensor
:
*
desc_
.
mutable_type
()
->
mutable_reader
()
->
mutable_lod_tensor
())
{
res
.
push_back
(
lod_tensor
.
mutable_tensor
());
}
return
res
;
default:
LOG
(
FATAL
)
<<
"Getting 'tensor_descs' is not supported by the type of var "
"%s."
<<
this
->
Name
();
}
}
}
// namespace pb
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/pb/var_desc.h
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
namespace
paddle
{
namespace
lite
{
namespace
pb
{
// convert between std::vector and protobuf repeated.
template
<
typename
T
>
inline
std
::
vector
<
T
>
RepeatedToVector
(
const
google
::
protobuf
::
RepeatedField
<
T
>
&
repeated_field
)
{
std
::
vector
<
T
>
ret
;
ret
.
reserve
(
repeated_field
.
size
());
std
::
copy
(
repeated_field
.
begin
(),
repeated_field
.
end
(),
std
::
back_inserter
(
ret
));
return
ret
;
}
template
<
typename
T
,
typename
RepeatedField
>
inline
void
VectorToRepeated
(
const
std
::
vector
<
T
>
&
vec
,
RepeatedField
*
repeated_field
)
{
repeated_field
->
Clear
();
repeated_field
->
Reserve
(
vec
.
size
());
for
(
const
auto
&
elem
:
vec
)
{
*
repeated_field
->
Add
()
=
elem
;
}
}
// Specialize vector<bool>.
template
<
typename
RepeatedField
>
inline
void
VectorToRepeated
(
const
std
::
vector
<
bool
>
&
vec
,
RepeatedField
*
repeated_field
)
{
repeated_field
->
Clear
();
repeated_field
->
Reserve
(
vec
.
size
());
for
(
auto
elem
:
vec
)
{
*
repeated_field
->
Add
()
=
elem
;
}
}
class
VarDesc
{
public:
explicit
VarDesc
(
const
std
::
string
&
name
)
{
desc_
.
set_name
(
name
);
// TODO(paddle-dev): Why default to lodtensor.
desc_
.
mutable_type
()
->
set_type
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
explicit
VarDesc
(
const
framework
::
proto
::
VarDesc
&
desc
)
:
desc_
(
desc
)
{}
framework
::
proto
::
VarDesc
*
Proto
()
{
return
&
desc_
;
}
std
::
string
Name
()
const
{
return
desc_
.
name
();
}
void
SetName
(
std
::
string
name
)
{
desc_
.
set_name
(
name
);
}
void
SetTensorDescNum
(
size_t
num
);
size_t
GetTensorDescNum
()
const
;
void
SetShape
(
const
std
::
vector
<
int64_t
>
&
dims
);
void
SetShapes
(
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
multiple_dims
);
std
::
vector
<
int64_t
>
GetShape
()
const
;
std
::
vector
<
std
::
vector
<
int64_t
>>
GetShapes
()
const
;
void
SetDataType
(
framework
::
proto
::
VarType
::
Type
data_type
);
void
SetDataTypes
(
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
&
multiple_data_type
);
framework
::
proto
::
VarType
::
Type
GetDataType
()
const
;
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
GetDataTypes
()
const
;
void
SetLoDLevel
(
int32_t
lod_level
);
void
SetLoDLevels
(
const
std
::
vector
<
int32_t
>
&
multiple_lod_level
);
int32_t
GetLoDLevel
()
const
;
std
::
vector
<
int32_t
>
GetLoDLevels
()
const
;
framework
::
proto
::
VarType
::
Type
GetType
()
const
;
void
SetType
(
framework
::
proto
::
VarType
::
Type
type
);
bool
Persistable
()
const
{
return
desc_
.
persistable
();
}
void
SetPersistable
(
bool
persistable
)
{
desc_
.
set_persistable
(
persistable
);
}
private:
const
framework
::
proto
::
VarType
::
TensorDesc
&
tensor_desc
()
const
;
std
::
vector
<
framework
::
proto
::
VarType
::
TensorDesc
>
tensor_descs
()
const
;
framework
::
proto
::
VarType
::
TensorDesc
*
mutable_tensor_desc
();
std
::
vector
<
framework
::
proto
::
VarType
::
TensorDesc
*>
mutable_tensor_descs
();
framework
::
proto
::
VarDesc
desc_
;
};
}
// namespace pb
}
// namespace framework
}
// namespace paddle
paddle/fluid/lite/operators/fc_op.h
浏览文件 @
70540d1b
...
@@ -46,8 +46,7 @@ class FcOpLite : public OpLite {
...
@@ -46,8 +46,7 @@ class FcOpLite : public OpLite {
*/
*/
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"Input"
).
front
();
auto
input
=
op_desc
.
Input
(
"Input"
).
front
();
auto
W
=
op_desc
.
Input
(
"W"
).
front
();
auto
W
=
op_desc
.
Input
(
"W"
).
front
();
auto
bias
=
op_desc
.
Input
(
"Bias"
).
front
();
auto
bias
=
op_desc
.
Input
(
"Bias"
).
front
();
...
@@ -58,8 +57,7 @@ class FcOpLite : public OpLite {
...
@@ -58,8 +57,7 @@ class FcOpLite : public OpLite {
param_
.
bias
=
scope
->
FindVar
(
bias
)
->
GetMutable
<
Tensor
>
();
param_
.
bias
=
scope
->
FindVar
(
bias
)
->
GetMutable
<
Tensor
>
();
CHECK
(
scope
->
FindVar
(
out
));
CHECK
(
scope
->
FindVar
(
out
));
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
Tensor
>
();
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
Tensor
>
();
param_
.
in_num_col_dims
=
param_
.
in_num_col_dims
=
op_desc
.
GetAttr
(
"in_num_col_dims"
).
get
<
int
>
();
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"in_num_col_dims"
));
CHECK
(
kernel_
);
CHECK
(
kernel_
);
kernel_
->
SetParam
(
param_
);
kernel_
->
SetParam
(
param_
);
...
...
paddle/fluid/lite/operators/feed_op.cc
浏览文件 @
70540d1b
...
@@ -35,8 +35,7 @@ class FeedOp : public OpLite {
...
@@ -35,8 +35,7 @@ class FeedOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
protected:
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
lite
::
Scope
*
scope
)
override
{
auto
feed_var_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
feed_var_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
*
feed_var
=
scope
->
FindVar
(
feed_var_name
);
auto
*
feed_var
=
scope
->
FindVar
(
feed_var_name
);
CHECK
(
feed_var
);
CHECK
(
feed_var
);
...
@@ -50,7 +49,7 @@ class FeedOp : public OpLite {
...
@@ -50,7 +49,7 @@ class FeedOp : public OpLite {
// NOTE need boost here
// NOTE need boost here
// TODO(Superjomn) drop the need of framework::op_desc
// TODO(Superjomn) drop the need of framework::op_desc
param_
.
col
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"col"
)
);
param_
.
col
=
opdesc
.
GetAttr
(
"col"
).
get
<
int
>
(
);
return
true
;
return
true
;
}
}
...
...
paddle/fluid/lite/operators/fetch_op.cc
浏览文件 @
70540d1b
...
@@ -33,8 +33,7 @@ class FetchOp : public OpLite {
...
@@ -33,8 +33,7 @@ class FetchOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
protected:
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
lite
::
Scope
*
scope
)
override
{
auto
_x
=
opdesc
.
Input
(
"X"
).
front
();
auto
_x
=
opdesc
.
Input
(
"X"
).
front
();
auto
*
x
=
scope
->
FindVar
(
_x
);
auto
*
x
=
scope
->
FindVar
(
_x
);
CHECK
(
x
);
CHECK
(
x
);
...
@@ -44,7 +43,7 @@ class FetchOp : public OpLite {
...
@@ -44,7 +43,7 @@ class FetchOp : public OpLite {
auto
*
out
=
scope
->
FindVar
(
_out
);
auto
*
out
=
scope
->
FindVar
(
_out
);
param_
.
fetch_list
=
out
->
GetMutable
<
std
::
vector
<
lite
::
Tensor
>>
();
param_
.
fetch_list
=
out
->
GetMutable
<
std
::
vector
<
lite
::
Tensor
>>
();
param_
.
col
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"col"
)
);
param_
.
col
=
opdesc
.
GetAttr
(
"col"
).
get
<
int
>
(
);
return
true
;
return
true
;
}
}
...
...
paddle/fluid/lite/operators/io_copy_op.cc
浏览文件 @
70540d1b
...
@@ -29,8 +29,7 @@ bool IoCopyOp::InferShape() const {
...
@@ -29,8 +29,7 @@ bool IoCopyOp::InferShape() const {
return
true
;
return
true
;
}
}
bool
IoCopyOp
::
Run
()
{
return
OpLite
::
Run
();
}
bool
IoCopyOp
::
Run
()
{
return
OpLite
::
Run
();
}
bool
IoCopyOp
::
AttachImpl
(
const
paddle
::
framework
::
OpDesc
&
opdesc
,
bool
IoCopyOp
::
AttachImpl
(
const
OpDesc
&
opdesc
,
paddle
::
lite
::
Scope
*
scope
)
{
paddle
::
lite
::
Scope
*
scope
)
{
auto
x
=
opdesc
.
Input
(
"Input"
).
front
();
auto
x
=
opdesc
.
Input
(
"Input"
).
front
();
auto
out
=
opdesc
.
Output
(
"Out"
).
front
();
auto
out
=
opdesc
.
Output
(
"Out"
).
front
();
param_
.
x
=
GetTensor
(
scope
,
x
);
param_
.
x
=
GetTensor
(
scope
,
x
);
...
...
paddle/fluid/lite/operators/io_copy_op.h
浏览文件 @
70540d1b
...
@@ -31,7 +31,7 @@ class IoCopyOp : public OpLite {
...
@@ -31,7 +31,7 @@ class IoCopyOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
protected:
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
private:
private:
operators
::
IoCopyParam
param_
;
operators
::
IoCopyParam
param_
;
...
...
paddle/fluid/lite/operators/mul_op.h
浏览文件 @
70540d1b
...
@@ -38,8 +38,7 @@ class MulOpLite : public OpLite {
...
@@ -38,8 +38,7 @@ class MulOpLite : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
W
=
op_desc
.
Input
(
"Y"
).
front
();
auto
W
=
op_desc
.
Input
(
"Y"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
...
@@ -48,8 +47,8 @@ class MulOpLite : public OpLite {
...
@@ -48,8 +47,8 @@ class MulOpLite : public OpLite {
param_
.
y
=
scope
->
FindVar
(
W
)
->
GetMutable
<
Tensor
>
();
param_
.
y
=
scope
->
FindVar
(
W
)
->
GetMutable
<
Tensor
>
();
CHECK
(
scope
->
FindVar
(
out
));
CHECK
(
scope
->
FindVar
(
out
));
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
Tensor
>
();
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
Tensor
>
();
param_
.
x_num_col_dims
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"x_num_col_dims"
)
);
param_
.
x_num_col_dims
=
op_desc
.
GetAttr
(
"x_num_col_dims"
).
get
<
int
>
(
);
param_
.
y_num_col_dims
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"y_num_col_dims"
)
);
param_
.
y_num_col_dims
=
op_desc
.
GetAttr
(
"y_num_col_dims"
).
get
<
int
>
(
);
return
true
;
return
true
;
}
}
...
...
paddle/fluid/lite/operators/relu_op.cc
浏览文件 @
70540d1b
...
@@ -31,7 +31,7 @@ bool ReluOp::InferShape() const {
...
@@ -31,7 +31,7 @@ bool ReluOp::InferShape() const {
return
true
;
return
true
;
}
}
bool
ReluOp
::
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
ReluOp
::
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
input
=
const_cast
<
Tensor
*>
(
param_
.
input
=
const_cast
<
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"Input"
).
front
())
->
Get
<
Tensor
>
());
&
scope
->
FindVar
(
opdesc
.
Input
(
"Input"
).
front
())
->
Get
<
Tensor
>
());
param_
.
output
=
param_
.
output
=
...
...
paddle/fluid/lite/operators/relu_op.h
浏览文件 @
70540d1b
...
@@ -32,7 +32,7 @@ class ReluOp : public OpLite {
...
@@ -32,7 +32,7 @@ class ReluOp : public OpLite {
bool
InferShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"tanh"
;
}
std
::
string
DebugString
()
const
override
{
return
"tanh"
;
}
...
...
paddle/fluid/lite/operators/scale_op.cc
浏览文件 @
70540d1b
...
@@ -46,18 +46,16 @@ class ScaleOp : public OpLite {
...
@@ -46,18 +46,16 @@ class ScaleOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
param_
.
x
=
scope
->
FindVar
(
x
)
->
GetMutable
<
Tensor
>
();
param_
.
x
=
scope
->
FindVar
(
x
)
->
GetMutable
<
Tensor
>
();
CHECK
(
scope
->
FindVar
(
out
));
CHECK
(
scope
->
FindVar
(
out
));
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
Tensor
>
();
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
Tensor
>
();
param_
.
scale
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"scale"
));
param_
.
scale
=
op_desc
.
GetAttr
(
"scale"
).
get
<
float
>
();
param_
.
bias
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"bias"
));
param_
.
bias
=
op_desc
.
GetAttr
(
"bias"
).
get
<
float
>
();
param_
.
bias_after_scale
=
param_
.
bias_after_scale
=
op_desc
.
GetAttr
(
"bias_after_scale"
).
get
<
bool
>
();
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"bias_after_scale"
));
CHECK
(
kernel_
);
CHECK
(
kernel_
);
kernel_
->
SetParam
(
param_
);
kernel_
->
SetParam
(
param_
);
...
...
paddle/fluid/lite/utils/varient.h
浏览文件 @
70540d1b
...
@@ -114,7 +114,8 @@ struct variant {
...
@@ -114,7 +114,8 @@ struct variant {
if
(
type_id
==
typeid
(
T
).
hash_code
())
if
(
type_id
==
typeid
(
T
).
hash_code
())
return
*
reinterpret_cast
<
T
*>
(
&
data
);
return
*
reinterpret_cast
<
T
*>
(
&
data
);
else
else
throw
std
::
bad_cast
();
LOG
(
FATAL
)
<<
"unmatched type get, should be "
<<
type_id
<<
" but get "
<<
typeid
(
T
).
name
();
}
}
~
variant
()
{
helper_t
::
destroy
(
type_id
,
&
data
);
}
~
variant
()
{
helper_t
::
destroy
(
type_id
,
&
data
);
}
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录