Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e7f32773
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e7f32773
编写于
5月 03, 2019
作者:
S
Superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enable optimized model persist
上级
f1ca00a4
变更
26
隐藏空白更改
内联
并排
Showing
26 changed file
with
302 addition
and
53 deletion
+302
-53
paddle/fluid/lite/api/CMakeLists.txt
paddle/fluid/lite/api/CMakeLists.txt
+2
-2
paddle/fluid/lite/api/cxx_api.cc
paddle/fluid/lite/api/cxx_api.cc
+9
-1
paddle/fluid/lite/api/cxx_api.h
paddle/fluid/lite/api/cxx_api.h
+13
-7
paddle/fluid/lite/api/cxx_api_test.cc
paddle/fluid/lite/api/cxx_api_test.cc
+10
-1
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+1
-1
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+8
-0
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+1
-1
paddle/fluid/lite/core/mir/generate_program_pass.h
paddle/fluid/lite/core/mir/generate_program_pass.h
+1
-0
paddle/fluid/lite/core/mir/ssa_graph.cc
paddle/fluid/lite/core/mir/ssa_graph.cc
+1
-1
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+1
-1
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+1
-1
paddle/fluid/lite/core/optimizer.cc
paddle/fluid/lite/core/optimizer.cc
+3
-0
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+8
-0
paddle/fluid/lite/core/optimizer_test.cc
paddle/fluid/lite/core/optimizer_test.cc
+4
-9
paddle/fluid/lite/core/program.cc
paddle/fluid/lite/core/program.cc
+56
-0
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+15
-2
paddle/fluid/lite/core/program_fake_utils.h
paddle/fluid/lite/core/program_fake_utils.h
+2
-4
paddle/fluid/lite/core/target_wrapper.h
paddle/fluid/lite/core/target_wrapper.h
+4
-0
paddle/fluid/lite/core/tensor.h
paddle/fluid/lite/core/tensor.h
+1
-1
paddle/fluid/lite/model_parser/CMakeLists.txt
paddle/fluid/lite/model_parser/CMakeLists.txt
+3
-1
paddle/fluid/lite/model_parser/model_parser.cc
paddle/fluid/lite/model_parser/model_parser.cc
+69
-0
paddle/fluid/lite/model_parser/model_parser.h
paddle/fluid/lite/model_parser/model_parser.h
+7
-0
paddle/fluid/lite/model_parser/pb/op_desc.cc
paddle/fluid/lite/model_parser/pb/op_desc.cc
+28
-0
paddle/fluid/lite/model_parser/pb/op_desc.h
paddle/fluid/lite/model_parser/pb/op_desc.h
+20
-20
paddle/fluid/lite/utils/all.h
paddle/fluid/lite/utils/all.h
+1
-0
paddle/fluid/lite/utils/io.h
paddle/fluid/lite/utils/io.h
+33
-0
未找到文件。
paddle/fluid/lite/api/CMakeLists.txt
浏览文件 @
e7f32773
if
(
LITE_WITH_CUDA
)
cc_library
(
cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda
)
cc_library
(
cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda
optimizer_lite
)
nv_test
(
test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite
)
else
()
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host
)
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host
)
cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host host_kernels
)
endif
()
paddle/fluid/lite/api/cxx_api.cc
浏览文件 @
e7f32773
...
...
@@ -13,7 +13,15 @@
// limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/platform/port.h"
namespace
paddle
{
namespace
lite
{}
// namespace lite
namespace
lite
{
void
Predictor
::
SaveModel
(
const
std
::
string
&
dir
)
{
MkDirRecursively
(
dir
.
c_str
());
program_
->
PersistModel
(
dir
,
program_desc_
);
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/api/cxx_api.h
浏览文件 @
e7f32773
...
...
@@ -31,19 +31,19 @@ class Predictor {
void
Build
(
const
std
::
string
&
model_path
,
const
Place
&
prefer_place
,
const
std
::
vector
<
Place
>&
valid_places
)
{
framework
::
proto
::
ProgramDesc
prog
;
LoadModel
(
model_path
,
scope_
.
get
(),
&
prog
);
LoadModel
(
model_path
,
scope_
.
get
(),
&
program_desc_
);
Program
program
(
prog
,
scope_
,
valid_places
);
Program
program
(
prog
ram_desc_
,
scope_
,
valid_places
);
Optimizer
optimizer
;
optimizer
.
KernelPickPreferPlace
(
prefer_place
);
optimizer_
.
KernelPickPreferPlace
(
prefer_place
);
core
::
KernelPickFactor
factor
;
factor
.
ConsiderTarget
();
optimizer
.
Run
(
std
::
move
(
program
),
valid_places
,
factor
);
program_
=
optimizer
.
GenRuntimeProgram
();
optimizer
_
.
Run
(
std
::
move
(
program
),
valid_places
,
factor
);
program_
=
optimizer
_
.
GenRuntimeProgram
();
}
void
SaveModel
(
const
std
::
string
&
dir
);
// Get offset-th col of feed.
Tensor
*
GetInput
(
size_t
offset
)
{
auto
*
_feed_list
=
program_
->
exec_scope
()
->
FindVar
(
"feed"
);
...
...
@@ -65,7 +65,13 @@ class Predictor {
void
Run
()
{
program_
->
Run
();
}
const
framework
::
proto
::
ProgramDesc
&
program_desc
()
const
{
return
program_desc_
;
}
private:
Optimizer
optimizer_
;
framework
::
proto
::
ProgramDesc
program_desc_
;
std
::
shared_ptr
<
Scope
>
scope_
;
std
::
unique_ptr
<
RuntimeProgram
>
program_
;
};
...
...
paddle/fluid/lite/api/cxx_api_test.cc
浏览文件 @
e7f32773
...
...
@@ -36,7 +36,7 @@ TEST(CXXApi, test) {
});
#endif
predictor
.
Build
(
"/home/chunwei/project
2
/models/model2"
,
predictor
.
Build
(
"/home/chunwei/project/models/model2"
,
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)},
valid_places
);
auto
*
input_tensor
=
predictor
.
GetInput
(
0
);
...
...
@@ -59,6 +59,15 @@ TEST(CXXApi, test) {
LOG
(
INFO
)
<<
"out "
<<
*
out
;
}
TEST
(
CXXApi
,
save_model
)
{
lite
::
Predictor
predictor
;
std
::
vector
<
Place
>
valid_places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
predictor
.
Build
(
"/home/chunwei/project/models/model2"
,
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)},
valid_places
);
predictor
.
SaveModel
(
"./optimized_model"
);
}
}
// namespace lite
}
// namespace paddle
...
...
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
e7f32773
...
...
@@ -12,13 +12,13 @@ cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_l
cc_library
(
kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite
)
cc_library
(
types_lite SRCS types.cc
)
cc_library
(
type_system SRCS type_system.cc DEPS tensor_lite
)
cc_library
(
optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager
)
cc_library
(
program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph
scope_lite op_registry_lite proto_desc op_lite
ops_lite
host_kernels
)
cc_library
(
program_lite SRCS program.cc DEPS op_lite kernel_lite
)
cc_library
(
optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite
)
cc_test
(
test_scope_lite SRCS scope_test.cc DEPS scope_lite
)
cc_test
(
test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86
)
...
...
paddle/fluid/lite/core/kernel.h
浏览文件 @
e7f32773
...
...
@@ -96,6 +96,14 @@ class KernelBase {
// Generate the key of the parameter type.
std
::
string
GenParamTypeKey
()
const
;
std
::
string
SerializeKernelType
()
const
{
std
::
stringstream
ss
;
ss
<<
op_type
()
<<
"/"
;
ss
<<
alias_
<<
"/"
;
ss
<<
place
();
return
ss
.
str
();
}
virtual
~
KernelBase
()
=
default
;
void
Torch
()
{}
...
...
paddle/fluid/lite/core/mir/generate_program_pass.cc
浏览文件 @
e7f32773
...
...
@@ -22,7 +22,7 @@ namespace mir {
void
GenerateProgramPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
LOG
(
INFO
)
<<
"final program
\n
"
<<
Visualize
(
graph
.
get
());
for
(
auto
&
item
:
graph
->
Instruc
tTopologicalOrder
())
{
for
(
auto
&
item
:
graph
->
Stm
tTopologicalOrder
())
{
if
(
item
->
IsStmt
())
{
auto
&
stmt
=
item
->
AsStmt
();
LOG
(
INFO
)
<<
stmt
;
...
...
paddle/fluid/lite/core/mir/generate_program_pass.h
浏览文件 @
e7f32773
...
...
@@ -31,6 +31,7 @@ class GenerateProgramPass : public ProgramPass {
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
;
std
::
unique_ptr
<
RuntimeProgram
>
GenProgram
()
{
LOG
(
INFO
)
<<
"insts.size "
<<
insts_
.
size
();
std
::
unique_ptr
<
RuntimeProgram
>
program
(
new
RuntimeProgram
(
std
::
move
(
insts_
)));
return
program
;
...
...
paddle/fluid/lite/core/mir/ssa_graph.cc
浏览文件 @
e7f32773
...
...
@@ -71,7 +71,7 @@ void SSAGraph::SortHelper(
ret
->
push_back
(
node
);
}
std
::
vector
<
mir
::
Node
*>
SSAGraph
::
Instruc
tTopologicalOrder
()
{
std
::
vector
<
mir
::
Node
*>
SSAGraph
::
Stm
tTopologicalOrder
()
{
CheckBidirectionalConnection
();
std
::
stack
<
mir
::
Node
*>
stack
;
...
...
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
e7f32773
...
...
@@ -39,7 +39,7 @@ class SSAGraph : GraphBase {
mir
::
Node
*
Argument
(
const
std
::
string
&
name
);
std
::
vector
<
mir
::
Node
*>
Instruc
tTopologicalOrder
();
std
::
vector
<
mir
::
Node
*>
Stm
tTopologicalOrder
();
// The inputs of the graph.
std
::
vector
<
mir
::
Node
*>
inputs
();
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
e7f32773
...
...
@@ -58,7 +58,7 @@ class VariablePlaceInferencePass : public DebugPass {
void
InferenceArgumentPlace
(
SSAGraph
*
graph
)
{
VLOG
(
3
)
<<
"param-type-registry:
\n
"
<<
ParamTypeRegistry
::
Global
();
for
(
auto
&
x
:
graph
->
Instruc
tTopologicalOrder
())
{
for
(
auto
&
x
:
graph
->
Stm
tTopologicalOrder
())
{
auto
&
inst
=
x
->
AsStmt
();
// The IoCopyOp is a tool operator, it won't support the type inference.
if
(
inst
.
op_type
==
"io_copy"
)
continue
;
...
...
paddle/fluid/lite/core/optimizer.cc
浏览文件 @
e7f32773
...
...
@@ -13,8 +13,11 @@
// limitations under the License.
#include "paddle/fluid/lite/core/optimizer.h"
#include <fstream>
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
#include "paddle/fluid/lite/core/mir/type_target_transform_pass.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
#include "paddle/fluid/lite/utils/all.h"
namespace
paddle
{
namespace
lite
{
...
...
paddle/fluid/lite/core/optimizer.h
浏览文件 @
e7f32773
...
...
@@ -22,6 +22,7 @@
#include "paddle/fluid/lite/core/mir/type_target_transform_pass.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
namespace
paddle
{
namespace
lite
{
...
...
@@ -35,6 +36,7 @@ class Optimizer {
void
Run
(
Program
&&
program
,
const
std
::
vector
<
Place
>&
valid_places
,
core
::
KernelPickFactor
kernel_pick_factor
,
const
std
::
vector
<
std
::
string
>&
passes
=
{})
{
program_
=
&
program
;
valid_places_
=
valid_places
;
CHECK
(
!
valid_places
.
empty
())
<<
"At least one valid_place should be set"
;
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
...
...
@@ -100,6 +102,11 @@ class Optimizer {
return
*
graph_
;
}
mir
::
SSAGraph
*
mutable_ssa_graph
()
{
CHECK
(
graph_
);
return
graph_
.
get
();
}
protected:
void
SpecifyKernelPickTactic
(
core
::
KernelPickFactor
factor
);
...
...
@@ -117,6 +124,7 @@ class Optimizer {
std
::
unique_ptr
<
mir
::
SSAGraph
>
graph_
;
std
::
vector
<
Place
>
valid_places_
;
lite
::
Scope
*
exec_scope_
{};
Program
*
program_
{};
};
}
// namespace lite
...
...
paddle/fluid/lite/core/optimizer_test.cc
浏览文件 @
e7f32773
...
...
@@ -28,15 +28,10 @@ TEST(Optimizer, test) {
auto
program
=
ProgramFaker
();
std
::
vector
<
Place
>
places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
auto
*
pick_pass
=
mir
::
PassManager
::
Global
().
LookUp
<
mir
::
StaticKernelPickPass
>
(
"static_kernel_pick_pass"
);
ASSERT_TRUE
(
pick_pass
!=
nullptr
);
pick_pass
->
mutable_kernel_pick_factors
()
->
ConsiderTarget
()
.
ConsiderPrecision
();
core
::
KernelPickFactor
factor
;
factor
.
ConsiderTarget
();
optimizer
.
Run
(
std
::
move
(
program
),
places
);
optimizer
.
Run
(
std
::
move
(
program
),
places
,
factor
);
auto
runtime_program
=
optimizer
.
GenRuntimeProgram
();
LOG
(
INFO
)
<<
"num statements "
<<
runtime_program
->
num_instructions
();
}
...
...
@@ -45,4 +40,4 @@ TEST(Optimizer, test) {
}
// namespace paddle
USE_LITE_OP
(
fc
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
kNCHW
,
def
);
paddle/fluid/lite/core/program.cc
浏览文件 @
e7f32773
...
...
@@ -13,3 +13,59 @@
// limitations under the License.
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/optimizer.h"
namespace
paddle
{
namespace
lite
{
void
RuntimeProgram
::
PersistModel
(
const
std
::
string
&
path
,
const
framework
::
proto
::
ProgramDesc
&
desc
)
{
// Persist model.
const
std
::
string
model_path
=
path
+
"/__model__"
;
std
::
ofstream
model_ostream
(
model_path
,
std
::
ios_base
::
binary
);
CHECK
(
model_ostream
.
is_open
());
const
std
::
string
pb_str
=
SerializeModelTopology
(
desc
);
model_ostream
.
write
(
pb_str
.
c_str
(),
pb_str
.
size
());
// Persist params.
const
std
::
string
params_path
=
path
+
"/params"
;
CHECK
(
!
IsFileExists
(
params_path
))
<<
"file "
<<
params_path
<<
" exists, can't overwrite"
;
std
::
ofstream
params_ostream
(
params_path
,
std
::
ios_base
::
binary
);
CHECK
(
params_ostream
.
is_open
());
framework
::
proto
::
ProgramDesc
latest_program
;
latest_program
.
ParseFromString
(
pb_str
);
SerializeParams
(
params_ostream
,
latest_program
);
}
std
::
string
RuntimeProgram
::
SerializeModelTopology
(
const
framework
::
proto
::
ProgramDesc
&
desc
)
{
const
std
::
string
kKernelTypeAttr
=
"__@kernel_type_attr@__"
;
auto
program_dummy
=
desc
;
program_dummy
.
mutable_blocks
(
0
)
->
clear_ops
();
for
(
auto
&
node
:
instructions_
)
{
auto
desc_dummy
=
node
.
op
()
->
op_info
()
->
desc
();
OpDesc
desc
(
desc_dummy
);
desc
.
SetAttr
(
kKernelTypeAttr
,
node
.
kernel
()
->
SerializeKernelType
());
// append new opdesc
*
program_dummy
.
mutable_blocks
(
0
)
->
add_ops
()
=
*
desc
.
Proto
();
}
return
program_dummy
.
SerializeAsString
();
}
void
RuntimeProgram
::
SerializeParams
(
std
::
ostream
&
os
,
const
framework
::
proto
::
ProgramDesc
&
desc
)
{
std
::
vector
<
std
::
string
>
ws
;
for
(
auto
&
item
:
desc
.
blocks
(
0
).
vars
())
{
if
(
item
.
name
()
==
"feed"
||
item
.
name
()
==
"fetch"
)
continue
;
if
(
item
.
persistable
())
{
ws
.
push_back
(
item
.
name
());
}
}
CHECK
(
exec_scope_
);
SerializeTensors
(
os
,
*
exec_scope_
,
ws
);
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/program.h
浏览文件 @
e7f32773
...
...
@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
...
...
@@ -115,6 +116,9 @@ struct Instruction {
return
os
;
}
const
OpLite
*
op
()
const
{
return
op_
.
get
();
}
const
KernelBase
*
kernel
()
const
{
return
kernel_
.
get
();
}
private:
std
::
shared_ptr
<
OpLite
>
op_
;
std
::
unique_ptr
<
KernelBase
>
kernel_
;
...
...
@@ -128,8 +132,8 @@ class RuntimeProgram {
public:
explicit
RuntimeProgram
(
std
::
vector
<
Instruction
>&&
insts
)
:
instructions_
(
std
::
move
(
insts
))
{
if
(
inst
s
.
empty
())
{
LOG
(
ERROR
)
<<
"no instructions"
;
if
(
inst
ructions_
.
empty
())
{
LOG
(
FATAL
)
<<
"no instructions"
;
}
}
...
...
@@ -140,11 +144,20 @@ class RuntimeProgram {
}
}
// Serialize the graph and save to the disk.
void
PersistModel
(
const
std
::
string
&
path
,
const
framework
::
proto
::
ProgramDesc
&
desc
);
void
set_exec_scope
(
lite
::
Scope
*
x
)
{
exec_scope_
=
x
;
}
lite
::
Scope
*
exec_scope
()
{
return
exec_scope_
;
}
size_t
num_instructions
()
const
{
return
instructions_
.
size
();
}
protected:
std
::
string
SerializeModelTopology
(
const
framework
::
proto
::
ProgramDesc
&
desc
);
void
SerializeParams
(
std
::
ostream
&
os
,
const
framework
::
proto
::
ProgramDesc
&
desc
);
private:
RuntimeProgram
(
const
RuntimeProgram
&
)
=
delete
;
std
::
vector
<
Instruction
>
instructions_
;
...
...
paddle/fluid/lite/core/program_fake_utils.h
浏览文件 @
e7f32773
...
...
@@ -32,21 +32,19 @@ Program FakeProgram() {
auto
b1v
=
program
.
scope
->
Var
(
b1
)
->
GetMutable
<
Tensor
>
();
auto
out1v
=
program
.
scope
->
Var
(
out1
)
->
GetMutable
<
Tensor
>
();
framework
::
OpDesc
desc
;
lite
::
OpDesc
desc
;
desc
.
SetInput
(
"Input"
,
{
x
});
desc
.
SetInput
(
"W"
,
{
w1
});
desc
.
SetInput
(
"Bias"
,
{
b1
});
desc
.
SetOutput
(
"Out"
,
{
out1
});
desc
.
SetType
(
"fc"
);
desc
.
SetAttr
(
"in_num_col_dims"
,
1
);
desc
.
Flush
();
desc
.
SetAttr
<
int
>
(
"in_num_col_dims"
,
1
);
// add to input
program
.
tmp_vars
.
push_back
(
w1
);
program
.
tmp_vars
.
push_back
(
b1
);
auto
fc_op
=
LiteOpRegistry
::
Global
().
Create
(
"fc"
);
fc_op
->
PickKernel
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
fc_op
->
Attach
(
desc
,
program
.
scope
.
get
());
program
.
ops
.
emplace_back
(
std
::
move
(
fc_op
));
...
...
paddle/fluid/lite/core/target_wrapper.h
浏览文件 @
e7f32773
...
...
@@ -164,6 +164,8 @@ class TargetWrapper {
};
// This interface should be specified by each kind of target.
using
TargetWrapperHost
=
TargetWrapper
<
TARGET
(
kHost
)
>
;
using
TargetWrapperX86
=
TargetWrapperHost
;
template
<
>
class
TargetWrapper
<
TARGET
(
kHost
)
>
{
public:
...
...
@@ -196,6 +198,8 @@ class TargetWrapper<TARGET(kHost)> {
};
#ifdef LITE_WITH_CUDA
using
TargetWrapperCuda
=
TargetWrapper
<
TARGET
(
kCUDA
),
cudaStream_t
,
cudaEvent_t
>
;
// This interface should be specified by each kind of target.
template
<
>
class
TargetWrapper
<
TARGET
(
kCUDA
),
cudaStream_t
,
cudaEvent_t
>
{
...
...
paddle/fluid/lite/core/tensor.h
浏览文件 @
e7f32773
...
...
@@ -58,7 +58,7 @@ class Tensor {
const
DDim
&
dims
()
const
{
return
dims_
;
}
const
LoD
&
lod
()
{
return
lod_
;
}
const
LoD
&
lod
()
const
{
return
lod_
;
}
LoD
*
mutable_lod
()
{
return
&
lod_
;
}
template
<
typename
T
>
...
...
paddle/fluid/lite/model_parser/CMakeLists.txt
浏览文件 @
e7f32773
cc_library
(
model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite tensor_lite scope_lite
)
cc_library
(
runtime_lite SRCS runtime.cc
)
cc_test
(
test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite
)
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
...
...
@@ -7,5 +6,8 @@ else()
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto
)
endif
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
cc_library
(
model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite tensor_lite scope_lite
compatible_pb_lite
)
add_subdirectory
(
pb
)
paddle/fluid/lite/model_parser/model_parser.cc
浏览文件 @
e7f32773
...
...
@@ -37,6 +37,7 @@ int SizeOfType(framework::proto::VarType::Type type) {
default:
LOG
(
FATAL
)
<<
"unknown data type"
;
}
return
-
1
;
}
void
TensorFromStream
(
std
::
istream
&
is
,
lite
::
Tensor
*
tensor
)
{
...
...
@@ -162,5 +163,73 @@ void LoadModel(const std::string &model_dir, Scope *scope,
}
}
void
TensorToStream
(
std
::
ostream
&
os
,
const
lite
::
Tensor
&
tensor
)
{
{
// the 1st field, uint32_t version
constexpr
uint32_t
version
=
0
;
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
version
),
sizeof
(
version
));
}
{
int
size
=
tensor
.
lod
().
size
();
// the 2st field, LoD information
// uint64_t lod_level
// uint64_t lod_level_1 size in byte.
// int* lod_level_1 data
// ...
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
size
),
sizeof
(
size
));
for
(
auto
&
each
:
tensor
.
lod
())
{
size
=
each
.
size
()
*
sizeof
(
each
.
front
());
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
size
),
sizeof
(
size
));
os
.
write
(
reinterpret_cast
<
const
char
*>
(
each
.
data
()),
static_cast
<
std
::
streamsize
>
(
size
));
}
}
{
// the 2nd field, tensor description
// int32_t size
// void* protobuf message
framework
::
proto
::
VarType
::
TensorDesc
desc
;
desc
.
set_data_type
(
framework
::
proto
::
VarType_Type_LOD_TENSOR
);
auto
dims
=
tensor
.
dims
();
auto
*
pb_dims
=
desc
.
mutable_dims
();
pb_dims
->
Resize
(
static_cast
<
int
>
(
dims
.
size
()),
0
);
std
::
copy
(
dims
.
begin
(),
dims
.
end
(),
pb_dims
->
begin
());
int32_t
size
=
desc
.
ByteSize
();
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
size
),
sizeof
(
size
));
auto
out
=
desc
.
SerializeAsString
();
os
.
write
(
out
.
data
(),
size
);
}
{
// the 3rd field, tensor data
uint64_t
size
=
tensor
.
memory_size
();
CHECK_LT
(
size
,
std
::
numeric_limits
<
std
::
streamsize
>::
max
())
<<
"Index overflow when writing tensor"
;
#ifdef LITE_WITH_CUDA
if
(
tensor
.
target
()
==
TARGET
(
kCUDA
))
{
std
::
unique_ptr
<
char
>
tmp_buffer
(
new
char
[
size
]);
TargetWrapperCuda
::
MemcpySync
(
tmp_buffer
.
get
(),
tensor
.
data
<
char
>
(),
tensor
.
memory_size
(),
IoDirection
::
DtoH
);
os
.
write
(
static_cast
<
const
char
*>
(
tmp_buffer
.
get
()),
static_cast
<
std
::
streamsize
>
(
size
));
}
else
#endif // LITE_WITH_CUDA
{
os
.
write
(
static_cast
<
const
char
*>
(
tensor
.
data
<
void
>
()),
static_cast
<
std
::
streamsize
>
(
size
));
}
}
}
void
SerializeTensors
(
std
::
ostream
&
os
,
const
lite
::
Scope
&
scope
,
const
std
::
vector
<
std
::
string
>
&
vars
)
{
// Store all the persistable vars.
for
(
const
auto
&
_var
:
vars
)
{
auto
*
var
=
scope
.
FindVar
(
_var
);
const
auto
&
tensor
=
var
->
Get
<
lite
::
Tensor
>
();
TensorToStream
(
os
,
tensor
);
}
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/model_parser.h
浏览文件 @
e7f32773
...
...
@@ -40,5 +40,12 @@ void LoadParam(const std::string& path, Variable* out);
void
LoadModel
(
const
std
::
string
&
model_dir
,
Scope
*
scope
,
framework
::
proto
::
ProgramDesc
*
prog
);
// Serialize tensors to ostream.
void
SerializeTensors
(
std
::
ostream
&
os
,
const
lite
::
Scope
&
scope
,
const
std
::
vector
<
std
::
string
>&
vars
);
// LoDTensor to ostream
void
TensorToStream
(
std
::
ostream
&
os
,
const
lite
::
Tensor
&
tensor
);
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/pb/op_desc.cc
浏览文件 @
e7f32773
...
...
@@ -13,3 +13,31 @@
// limitations under the License.
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
namespace
paddle
{
namespace
lite
{
namespace
pb
{
template
<
>
void
OpDesc
::
SetAttr
<
std
::
string
>
(
const
std
::
string
&
name
,
const
std
::
string
&
v
)
{
auto
&
xs
=
*
desc_
.
mutable_attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
if
(
it
==
xs
.
end
())
{
auto
*
attr
=
xs
.
Add
();
attr
->
set_name
(
name
);
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
}
it
->
set_type
(
framework
::
proto
::
STRING
);
it
->
set_s
(
v
.
c_str
());
}
}
// namespace pb
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/pb/op_desc.h
浏览文件 @
e7f32773
...
...
@@ -120,28 +120,24 @@ class OpDesc {
if
(
it
==
xs
.
end
())
{
auto
*
attr
=
xs
.
Add
();
attr
->
set_name
(
name
);
it
=
std
::
find
(
xs
.
begin
(),
xs
.
end
(),
name
);
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
}
switch
(
typeid
(
T
).
hash_code
())
{
case
typeid
(
int
).
hash_code
():
it
->
set_type
(
framework
::
proto
::
INT
);
it
->
set_i
(
v
);
break
;
case
typeid
(
float
).
hash_code
():
it
->
set_type
(
framework
::
proto
::
FLOAT
);
it
->
set_f
(
v
);
break
;
case
typeid
(
std
::
string
).
hash_code
():
it
->
set_type
(
framework
::
proto
::
STRING
);
it
->
set_s
(
v
.
c_str
());
break
;
case
typeid
(
std
::
string
).
hash_code
():
it
->
set_type
(
framework
::
proto
::
BOOLEAN
);
it
->
set_b
(
v
);
break
;
default:
LOG
(
FATAL
)
<<
"unsupport attr type"
;
size_t
hash
=
typeid
(
T
).
hash_code
();
if
(
hash
==
typeid
(
int
).
hash_code
())
{
it
->
set_type
(
framework
::
proto
::
INT
);
it
->
set_i
(
v
);
}
else
if
(
hash
==
typeid
(
float
).
hash_code
())
{
it
->
set_type
(
framework
::
proto
::
FLOAT
);
it
->
set_f
(
v
);
}
else
if
(
hash
==
typeid
(
bool
).
hash_code
())
{
it
->
set_type
(
framework
::
proto
::
BOOLEAN
);
it
->
set_b
(
v
);
}
else
{
LOG
(
FATAL
)
<<
"unsupport attr type"
;
}
}
...
...
@@ -229,6 +225,10 @@ class OpDesc {
framework
::
proto
::
OpDesc
desc_
;
};
template
<
>
void
OpDesc
::
SetAttr
<
std
::
string
>
(
const
std
::
string
&
name
,
const
std
::
string
&
v
);
}
// namespace pb
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/utils/all.h
浏览文件 @
e7f32773
...
...
@@ -17,5 +17,6 @@
#include "paddle/fluid/lite/utils/check.h"
#include "paddle/fluid/lite/utils/factory.h"
#include "paddle/fluid/lite/utils/hash.h"
#include "paddle/fluid/lite/utils/io.h"
#include "paddle/fluid/lite/utils/macros.h"
#include "paddle/fluid/lite/utils/varient.h"
paddle/fluid/lite/utils/io.h
0 → 100644
浏览文件 @
e7f32773
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fstream>
#include <string>
namespace
paddle
{
namespace
lite
{
static
bool
IsFileExists
(
const
std
::
string
&
path
)
{
std
::
ifstream
file
(
path
);
bool
res
=
file
.
is_open
();
if
(
res
)
{
file
.
close
();
}
return
res
;
}
}
// namespace lite
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录