Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e7f32773
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e7f32773
编写于
5月 03, 2019
作者:
S
Superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enable optimized model persist
上级
f1ca00a4
变更
26
隐藏空白更改
内联
并排
Showing
26 changed file
with
302 addition
and
53 deletion
+302
-53
paddle/fluid/lite/api/CMakeLists.txt
paddle/fluid/lite/api/CMakeLists.txt
+2
-2
paddle/fluid/lite/api/cxx_api.cc
paddle/fluid/lite/api/cxx_api.cc
+9
-1
paddle/fluid/lite/api/cxx_api.h
paddle/fluid/lite/api/cxx_api.h
+13
-7
paddle/fluid/lite/api/cxx_api_test.cc
paddle/fluid/lite/api/cxx_api_test.cc
+10
-1
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+1
-1
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+8
-0
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+1
-1
paddle/fluid/lite/core/mir/generate_program_pass.h
paddle/fluid/lite/core/mir/generate_program_pass.h
+1
-0
paddle/fluid/lite/core/mir/ssa_graph.cc
paddle/fluid/lite/core/mir/ssa_graph.cc
+1
-1
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+1
-1
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+1
-1
paddle/fluid/lite/core/optimizer.cc
paddle/fluid/lite/core/optimizer.cc
+3
-0
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+8
-0
paddle/fluid/lite/core/optimizer_test.cc
paddle/fluid/lite/core/optimizer_test.cc
+4
-9
paddle/fluid/lite/core/program.cc
paddle/fluid/lite/core/program.cc
+56
-0
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+15
-2
paddle/fluid/lite/core/program_fake_utils.h
paddle/fluid/lite/core/program_fake_utils.h
+2
-4
paddle/fluid/lite/core/target_wrapper.h
paddle/fluid/lite/core/target_wrapper.h
+4
-0
paddle/fluid/lite/core/tensor.h
paddle/fluid/lite/core/tensor.h
+1
-1
paddle/fluid/lite/model_parser/CMakeLists.txt
paddle/fluid/lite/model_parser/CMakeLists.txt
+3
-1
paddle/fluid/lite/model_parser/model_parser.cc
paddle/fluid/lite/model_parser/model_parser.cc
+69
-0
paddle/fluid/lite/model_parser/model_parser.h
paddle/fluid/lite/model_parser/model_parser.h
+7
-0
paddle/fluid/lite/model_parser/pb/op_desc.cc
paddle/fluid/lite/model_parser/pb/op_desc.cc
+28
-0
paddle/fluid/lite/model_parser/pb/op_desc.h
paddle/fluid/lite/model_parser/pb/op_desc.h
+20
-20
paddle/fluid/lite/utils/all.h
paddle/fluid/lite/utils/all.h
+1
-0
paddle/fluid/lite/utils/io.h
paddle/fluid/lite/utils/io.h
+33
-0
未找到文件。
paddle/fluid/lite/api/CMakeLists.txt
浏览文件 @
e7f32773
if
(
LITE_WITH_CUDA
)
cc_library
(
cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda
)
cc_library
(
cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda
optimizer_lite
)
nv_test
(
test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite
)
else
()
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host
)
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host
)
cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host host_kernels
)
endif
()
paddle/fluid/lite/api/cxx_api.cc
浏览文件 @
e7f32773
...
...
@@ -13,7 +13,15 @@
// limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/platform/port.h"
namespace
paddle
{
namespace
lite
{}
// namespace lite
namespace
lite
{
void
Predictor
::
SaveModel
(
const
std
::
string
&
dir
)
{
MkDirRecursively
(
dir
.
c_str
());
program_
->
PersistModel
(
dir
,
program_desc_
);
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/api/cxx_api.h
浏览文件 @
e7f32773
...
...
@@ -31,19 +31,19 @@ class Predictor {
void
Build
(
const
std
::
string
&
model_path
,
const
Place
&
prefer_place
,
const
std
::
vector
<
Place
>&
valid_places
)
{
framework
::
proto
::
ProgramDesc
prog
;
LoadModel
(
model_path
,
scope_
.
get
(),
&
prog
);
LoadModel
(
model_path
,
scope_
.
get
(),
&
program_desc_
);
Program
program
(
prog
,
scope_
,
valid_places
);
Program
program
(
prog
ram_desc_
,
scope_
,
valid_places
);
Optimizer
optimizer
;
optimizer
.
KernelPickPreferPlace
(
prefer_place
);
optimizer_
.
KernelPickPreferPlace
(
prefer_place
);
core
::
KernelPickFactor
factor
;
factor
.
ConsiderTarget
();
optimizer
.
Run
(
std
::
move
(
program
),
valid_places
,
factor
);
program_
=
optimizer
.
GenRuntimeProgram
();
optimizer
_
.
Run
(
std
::
move
(
program
),
valid_places
,
factor
);
program_
=
optimizer
_
.
GenRuntimeProgram
();
}
void
SaveModel
(
const
std
::
string
&
dir
);
// Get offset-th col of feed.
Tensor
*
GetInput
(
size_t
offset
)
{
auto
*
_feed_list
=
program_
->
exec_scope
()
->
FindVar
(
"feed"
);
...
...
@@ -65,7 +65,13 @@ class Predictor {
void
Run
()
{
program_
->
Run
();
}
const
framework
::
proto
::
ProgramDesc
&
program_desc
()
const
{
return
program_desc_
;
}
private:
Optimizer
optimizer_
;
framework
::
proto
::
ProgramDesc
program_desc_
;
std
::
shared_ptr
<
Scope
>
scope_
;
std
::
unique_ptr
<
RuntimeProgram
>
program_
;
};
...
...
paddle/fluid/lite/api/cxx_api_test.cc
浏览文件 @
e7f32773
...
...
@@ -36,7 +36,7 @@ TEST(CXXApi, test) {
});
#endif
predictor
.
Build
(
"/home/chunwei/project
2
/models/model2"
,
predictor
.
Build
(
"/home/chunwei/project/models/model2"
,
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)},
valid_places
);
auto
*
input_tensor
=
predictor
.
GetInput
(
0
);
...
...
@@ -59,6 +59,15 @@ TEST(CXXApi, test) {
LOG
(
INFO
)
<<
"out "
<<
*
out
;
}
TEST
(
CXXApi
,
save_model
)
{
lite
::
Predictor
predictor
;
std
::
vector
<
Place
>
valid_places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
predictor
.
Build
(
"/home/chunwei/project/models/model2"
,
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)},
valid_places
);
predictor
.
SaveModel
(
"./optimized_model"
);
}
}
// namespace lite
}
// namespace paddle
...
...
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
e7f32773
...
...
@@ -12,13 +12,13 @@ cc_library(op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_l
cc_library
(
kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite
)
cc_library
(
types_lite SRCS types.cc
)
cc_library
(
type_system SRCS type_system.cc DEPS tensor_lite
)
cc_library
(
optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager
)
cc_library
(
program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph
scope_lite op_registry_lite proto_desc op_lite
ops_lite
host_kernels
)
cc_library
(
program_lite SRCS program.cc DEPS op_lite kernel_lite
)
cc_library
(
optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite
)
cc_test
(
test_scope_lite SRCS scope_test.cc DEPS scope_lite
)
cc_test
(
test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86
)
...
...
paddle/fluid/lite/core/kernel.h
浏览文件 @
e7f32773
...
...
@@ -96,6 +96,14 @@ class KernelBase {
// Generate the key of the parameter type.
std
::
string
GenParamTypeKey
()
const
;
std
::
string
SerializeKernelType
()
const
{
std
::
stringstream
ss
;
ss
<<
op_type
()
<<
"/"
;
ss
<<
alias_
<<
"/"
;
ss
<<
place
();
return
ss
.
str
();
}
virtual
~
KernelBase
()
=
default
;
void
Torch
()
{}
...
...
paddle/fluid/lite/core/mir/generate_program_pass.cc
浏览文件 @
e7f32773
...
...
@@ -22,7 +22,7 @@ namespace mir {
void
GenerateProgramPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
LOG
(
INFO
)
<<
"final program
\n
"
<<
Visualize
(
graph
.
get
());
for
(
auto
&
item
:
graph
->
Instruc
tTopologicalOrder
())
{
for
(
auto
&
item
:
graph
->
Stm
tTopologicalOrder
())
{
if
(
item
->
IsStmt
())
{
auto
&
stmt
=
item
->
AsStmt
();
LOG
(
INFO
)
<<
stmt
;
...
...
paddle/fluid/lite/core/mir/generate_program_pass.h
浏览文件 @
e7f32773
...
...
@@ -31,6 +31,7 @@ class GenerateProgramPass : public ProgramPass {
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
;
std
::
unique_ptr
<
RuntimeProgram
>
GenProgram
()
{
LOG
(
INFO
)
<<
"insts.size "
<<
insts_
.
size
();
std
::
unique_ptr
<
RuntimeProgram
>
program
(
new
RuntimeProgram
(
std
::
move
(
insts_
)));
return
program
;
...
...
paddle/fluid/lite/core/mir/ssa_graph.cc
浏览文件 @
e7f32773
...
...
@@ -71,7 +71,7 @@ void SSAGraph::SortHelper(
ret
->
push_back
(
node
);
}
std
::
vector
<
mir
::
Node
*>
SSAGraph
::
Instruc
tTopologicalOrder
()
{
std
::
vector
<
mir
::
Node
*>
SSAGraph
::
Stm
tTopologicalOrder
()
{
CheckBidirectionalConnection
();
std
::
stack
<
mir
::
Node
*>
stack
;
...
...
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
e7f32773
...
...
@@ -39,7 +39,7 @@ class SSAGraph : GraphBase {
mir
::
Node
*
Argument
(
const
std
::
string
&
name
);
std
::
vector
<
mir
::
Node
*>
Instruc
tTopologicalOrder
();
std
::
vector
<
mir
::
Node
*>
Stm
tTopologicalOrder
();
// The inputs of the graph.
std
::
vector
<
mir
::
Node
*>
inputs
();
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
e7f32773
...
...
@@ -58,7 +58,7 @@ class VariablePlaceInferencePass : public DebugPass {
void
InferenceArgumentPlace
(
SSAGraph
*
graph
)
{
VLOG
(
3
)
<<
"param-type-registry:
\n
"
<<
ParamTypeRegistry
::
Global
();
for
(
auto
&
x
:
graph
->
Instruc
tTopologicalOrder
())
{
for
(
auto
&
x
:
graph
->
Stm
tTopologicalOrder
())
{
auto
&
inst
=
x
->
AsStmt
();
// The IoCopyOp is a tool operator, it won't support the type inference.
if
(
inst
.
op_type
==
"io_copy"
)
continue
;
...
...
paddle/fluid/lite/core/optimizer.cc
浏览文件 @
e7f32773
...
...
@@ -13,8 +13,11 @@
// limitations under the License.
#include "paddle/fluid/lite/core/optimizer.h"
#include <fstream>
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
#include "paddle/fluid/lite/core/mir/type_target_transform_pass.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
#include "paddle/fluid/lite/utils/all.h"
namespace
paddle
{
namespace
lite
{
...
...
paddle/fluid/lite/core/optimizer.h
浏览文件 @
e7f32773
...
...
@@ -22,6 +22,7 @@
#include "paddle/fluid/lite/core/mir/type_target_transform_pass.h"
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/model_parser/model_parser.h"
namespace
paddle
{
namespace
lite
{
...
...
@@ -35,6 +36,7 @@ class Optimizer {
void
Run
(
Program
&&
program
,
const
std
::
vector
<
Place
>&
valid_places
,
core
::
KernelPickFactor
kernel_pick_factor
,
const
std
::
vector
<
std
::
string
>&
passes
=
{})
{
program_
=
&
program
;
valid_places_
=
valid_places
;
CHECK
(
!
valid_places
.
empty
())
<<
"At least one valid_place should be set"
;
CHECK
(
!
graph_
)
<<
"duplicate optimize found"
;
...
...
@@ -100,6 +102,11 @@ class Optimizer {
return
*
graph_
;
}
mir
::
SSAGraph
*
mutable_ssa_graph
()
{
CHECK
(
graph_
);
return
graph_
.
get
();
}
protected:
void
SpecifyKernelPickTactic
(
core
::
KernelPickFactor
factor
);
...
...
@@ -117,6 +124,7 @@ class Optimizer {
std
::
unique_ptr
<
mir
::
SSAGraph
>
graph_
;
std
::
vector
<
Place
>
valid_places_
;
lite
::
Scope
*
exec_scope_
{};
Program
*
program_
{};
};
}
// namespace lite
...
...
paddle/fluid/lite/core/optimizer_test.cc
浏览文件 @
e7f32773
...
...
@@ -28,15 +28,10 @@ TEST(Optimizer, test) {
auto
program
=
ProgramFaker
();
std
::
vector
<
Place
>
places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
auto
*
pick_pass
=
mir
::
PassManager
::
Global
().
LookUp
<
mir
::
StaticKernelPickPass
>
(
"static_kernel_pick_pass"
);
ASSERT_TRUE
(
pick_pass
!=
nullptr
);
pick_pass
->
mutable_kernel_pick_factors
()
->
ConsiderTarget
()
.
ConsiderPrecision
();
core
::
KernelPickFactor
factor
;
factor
.
ConsiderTarget
();
optimizer
.
Run
(
std
::
move
(
program
),
places
);
optimizer
.
Run
(
std
::
move
(
program
),
places
,
factor
);
auto
runtime_program
=
optimizer
.
GenRuntimeProgram
();
LOG
(
INFO
)
<<
"num statements "
<<
runtime_program
->
num_instructions
();
}
...
...
@@ -45,4 +40,4 @@ TEST(Optimizer, test) {
}
// namespace paddle
USE_LITE_OP
(
fc
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
def
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
kNCHW
,
def
);
paddle/fluid/lite/core/program.cc
浏览文件 @
e7f32773
...
...
@@ -13,3 +13,59 @@
// limitations under the License.
#include "paddle/fluid/lite/core/program.h"
#include "paddle/fluid/lite/core/optimizer.h"
namespace
paddle
{
namespace
lite
{
void
RuntimeProgram
::
PersistModel
(
const
std
::
string
&
path
,
const
framework
::
proto
::
ProgramDesc
&
desc
)
{
// Persist model.
const
std
::
string
model_path
=
path
+
"/__model__"
;
std
::
ofstream
model_ostream
(
model_path
,
std
::
ios_base
::
binary
);
CHECK
(
model_ostream
.
is_open
());
const
std
::
string
pb_str
=
SerializeModelTopology
(
desc
);
model_ostream
.
write
(
pb_str
.
c_str
(),
pb_str
.
size
());
// Persist params.
const
std
::
string
params_path
=
path
+
"/params"
;
CHECK
(
!
IsFileExists
(
params_path
))
<<
"file "
<<
params_path
<<
" exists, can't overwrite"
;
std
::
ofstream
params_ostream
(
params_path
,
std
::
ios_base
::
binary
);
CHECK
(
params_ostream
.
is_open
());
framework
::
proto
::
ProgramDesc
latest_program
;
latest_program
.
ParseFromString
(
pb_str
);
SerializeParams
(
params_ostream
,
latest_program
);
}
std
::
string
RuntimeProgram
::
SerializeModelTopology
(
const
framework
::
proto
::
ProgramDesc
&
desc
)
{
const
std
::
string
kKernelTypeAttr
=
"__@kernel_type_attr@__"
;
auto
program_dummy
=
desc
;
program_dummy
.
mutable_blocks
(
0
)
->
clear_ops
();
for
(
auto
&
node
:
instructions_
)
{
auto
desc_dummy
=
node
.
op
()
->
op_info
()
->
desc
();
OpDesc
desc
(
desc_dummy
);
desc
.
SetAttr
(
kKernelTypeAttr
,
node
.
kernel
()
->
SerializeKernelType
());
// append new opdesc
*
program_dummy
.
mutable_blocks
(
0
)
->
add_ops
()
=
*
desc
.
Proto
();
}
return
program_dummy
.
SerializeAsString
();
}
void
RuntimeProgram
::
SerializeParams
(
std
::
ostream
&
os
,
const
framework
::
proto
::
ProgramDesc
&
desc
)
{
std
::
vector
<
std
::
string
>
ws
;
for
(
auto
&
item
:
desc
.
blocks
(
0
).
vars
())
{
if
(
item
.
name
()
==
"feed"
||
item
.
name
()
==
"fetch"
)
continue
;
if
(
item
.
persistable
())
{
ws
.
push_back
(
item
.
name
());
}
}
CHECK
(
exec_scope_
);
SerializeTensors
(
os
,
*
exec_scope_
,
ws
);
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/program.h
浏览文件 @
e7f32773
...
...
@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
...
...
@@ -115,6 +116,9 @@ struct Instruction {
return
os
;
}
const
OpLite
*
op
()
const
{
return
op_
.
get
();
}
const
KernelBase
*
kernel
()
const
{
return
kernel_
.
get
();
}
private:
std
::
shared_ptr
<
OpLite
>
op_
;
std
::
unique_ptr
<
KernelBase
>
kernel_
;
...
...
@@ -128,8 +132,8 @@ class RuntimeProgram {
public:
explicit
RuntimeProgram
(
std
::
vector
<
Instruction
>&&
insts
)
:
instructions_
(
std
::
move
(
insts
))
{
if
(
inst
s
.
empty
())
{
LOG
(
ERROR
)
<<
"no instructions"
;
if
(
inst
ructions_
.
empty
())
{
LOG
(
FATAL
)
<<
"no instructions"
;
}
}
...
...
@@ -140,11 +144,20 @@ class RuntimeProgram {
}
}
// Serialize the graph and save to the disk.
void
PersistModel
(
const
std
::
string
&
path
,
const
framework
::
proto
::
ProgramDesc
&
desc
);
void
set_exec_scope
(
lite
::
Scope
*
x
)
{
exec_scope_
=
x
;
}
lite
::
Scope
*
exec_scope
()
{
return
exec_scope_
;
}
size_t
num_instructions
()
const
{
return
instructions_
.
size
();
}
protected:
std
::
string
SerializeModelTopology
(
const
framework
::
proto
::
ProgramDesc
&
desc
);
void
SerializeParams
(
std
::
ostream
&
os
,
const
framework
::
proto
::
ProgramDesc
&
desc
);
private:
RuntimeProgram
(
const
RuntimeProgram
&
)
=
delete
;
std
::
vector
<
Instruction
>
instructions_
;
...
...
paddle/fluid/lite/core/program_fake_utils.h
浏览文件 @
e7f32773
...
...
@@ -32,21 +32,19 @@ Program FakeProgram() {
auto
b1v
=
program
.
scope
->
Var
(
b1
)
->
GetMutable
<
Tensor
>
();
auto
out1v
=
program
.
scope
->
Var
(
out1
)
->
GetMutable
<
Tensor
>
();
framework
::
OpDesc
desc
;
lite
::
OpDesc
desc
;
desc
.
SetInput
(
"Input"
,
{
x
});
desc
.
SetInput
(
"W"
,
{
w1
});
desc
.
SetInput
(
"Bias"
,
{
b1
});
desc
.
SetOutput
(
"Out"
,
{
out1
});
desc
.
SetType
(
"fc"
);
desc
.
SetAttr
(
"in_num_col_dims"
,
1
);
desc
.
Flush
();
desc
.
SetAttr
<
int
>
(
"in_num_col_dims"
,
1
);
// add to input
program
.
tmp_vars
.
push_back
(
w1
);
program
.
tmp_vars
.
push_back
(
b1
);
auto
fc_op
=
LiteOpRegistry
::
Global
().
Create
(
"fc"
);
fc_op
->
PickKernel
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
fc_op
->
Attach
(
desc
,
program
.
scope
.
get
());
program
.
ops
.
emplace_back
(
std
::
move
(
fc_op
));
...
...
paddle/fluid/lite/core/target_wrapper.h
浏览文件 @
e7f32773
...
...
@@ -164,6 +164,8 @@ class TargetWrapper {
};
// This interface should be specified by each kind of target.
using
TargetWrapperHost
=
TargetWrapper
<
TARGET
(
kHost
)
>
;
using
TargetWrapperX86
=
TargetWrapperHost
;
template
<
>
class
TargetWrapper
<
TARGET
(
kHost
)
>
{
public:
...
...
@@ -196,6 +198,8 @@ class TargetWrapper<TARGET(kHost)> {
};
#ifdef LITE_WITH_CUDA
using
TargetWrapperCuda
=
TargetWrapper
<
TARGET
(
kCUDA
),
cudaStream_t
,
cudaEvent_t
>
;
// This interface should be specified by each kind of target.
template
<
>
class
TargetWrapper
<
TARGET
(
kCUDA
),
cudaStream_t
,
cudaEvent_t
>
{
...
...
paddle/fluid/lite/core/tensor.h
浏览文件 @
e7f32773
...
...
@@ -58,7 +58,7 @@ class Tensor {
const
DDim
&
dims
()
const
{
return
dims_
;
}
const
LoD
&
lod
()
{
return
lod_
;
}
const
LoD
&
lod
()
const
{
return
lod_
;
}
LoD
*
mutable_lod
()
{
return
&
lod_
;
}
template
<
typename
T
>
...
...
paddle/fluid/lite/model_parser/CMakeLists.txt
浏览文件 @
e7f32773
cc_library
(
model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite tensor_lite scope_lite
)
cc_library
(
runtime_lite SRCS runtime.cc
)
cc_test
(
test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite
)
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
...
...
@@ -7,5 +6,8 @@ else()
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto
)
endif
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
cc_library
(
model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite tensor_lite scope_lite
compatible_pb_lite
)
add_subdirectory
(
pb
)
paddle/fluid/lite/model_parser/model_parser.cc
浏览文件 @
e7f32773
...
...
@@ -37,6 +37,7 @@ int SizeOfType(framework::proto::VarType::Type type) {
default:
LOG
(
FATAL
)
<<
"unknown data type"
;
}
return
-
1
;
}
void
TensorFromStream
(
std
::
istream
&
is
,
lite
::
Tensor
*
tensor
)
{
...
...
@@ -162,5 +163,73 @@ void LoadModel(const std::string &model_dir, Scope *scope,
}
}
void
TensorToStream
(
std
::
ostream
&
os
,
const
lite
::
Tensor
&
tensor
)
{
{
// the 1st field, uint32_t version
constexpr
uint32_t
version
=
0
;
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
version
),
sizeof
(
version
));
}
{
int
size
=
tensor
.
lod
().
size
();
// the 2st field, LoD information
// uint64_t lod_level
// uint64_t lod_level_1 size in byte.
// int* lod_level_1 data
// ...
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
size
),
sizeof
(
size
));
for
(
auto
&
each
:
tensor
.
lod
())
{
size
=
each
.
size
()
*
sizeof
(
each
.
front
());
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
size
),
sizeof
(
size
));
os
.
write
(
reinterpret_cast
<
const
char
*>
(
each
.
data
()),
static_cast
<
std
::
streamsize
>
(
size
));
}
}
{
// the 2nd field, tensor description
// int32_t size
// void* protobuf message
framework
::
proto
::
VarType
::
TensorDesc
desc
;
desc
.
set_data_type
(
framework
::
proto
::
VarType_Type_LOD_TENSOR
);
auto
dims
=
tensor
.
dims
();
auto
*
pb_dims
=
desc
.
mutable_dims
();
pb_dims
->
Resize
(
static_cast
<
int
>
(
dims
.
size
()),
0
);
std
::
copy
(
dims
.
begin
(),
dims
.
end
(),
pb_dims
->
begin
());
int32_t
size
=
desc
.
ByteSize
();
os
.
write
(
reinterpret_cast
<
const
char
*>
(
&
size
),
sizeof
(
size
));
auto
out
=
desc
.
SerializeAsString
();
os
.
write
(
out
.
data
(),
size
);
}
{
// the 3rd field, tensor data
uint64_t
size
=
tensor
.
memory_size
();
CHECK_LT
(
size
,
std
::
numeric_limits
<
std
::
streamsize
>::
max
())
<<
"Index overflow when writing tensor"
;
#ifdef LITE_WITH_CUDA
if
(
tensor
.
target
()
==
TARGET
(
kCUDA
))
{
std
::
unique_ptr
<
char
>
tmp_buffer
(
new
char
[
size
]);
TargetWrapperCuda
::
MemcpySync
(
tmp_buffer
.
get
(),
tensor
.
data
<
char
>
(),
tensor
.
memory_size
(),
IoDirection
::
DtoH
);
os
.
write
(
static_cast
<
const
char
*>
(
tmp_buffer
.
get
()),
static_cast
<
std
::
streamsize
>
(
size
));
}
else
#endif // LITE_WITH_CUDA
{
os
.
write
(
static_cast
<
const
char
*>
(
tensor
.
data
<
void
>
()),
static_cast
<
std
::
streamsize
>
(
size
));
}
}
}
void
SerializeTensors
(
std
::
ostream
&
os
,
const
lite
::
Scope
&
scope
,
const
std
::
vector
<
std
::
string
>
&
vars
)
{
// Store all the persistable vars.
for
(
const
auto
&
_var
:
vars
)
{
auto
*
var
=
scope
.
FindVar
(
_var
);
const
auto
&
tensor
=
var
->
Get
<
lite
::
Tensor
>
();
TensorToStream
(
os
,
tensor
);
}
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/model_parser.h
浏览文件 @
e7f32773
...
...
@@ -40,5 +40,12 @@ void LoadParam(const std::string& path, Variable* out);
void
LoadModel
(
const
std
::
string
&
model_dir
,
Scope
*
scope
,
framework
::
proto
::
ProgramDesc
*
prog
);
// Serialize tensors to ostream.
void
SerializeTensors
(
std
::
ostream
&
os
,
const
lite
::
Scope
&
scope
,
const
std
::
vector
<
std
::
string
>&
vars
);
// LoDTensor to ostream
void
TensorToStream
(
std
::
ostream
&
os
,
const
lite
::
Tensor
&
tensor
);
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/pb/op_desc.cc
浏览文件 @
e7f32773
...
...
@@ -13,3 +13,31 @@
// limitations under the License.
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
namespace
paddle
{
namespace
lite
{
namespace
pb
{
template
<
>
void
OpDesc
::
SetAttr
<
std
::
string
>
(
const
std
::
string
&
name
,
const
std
::
string
&
v
)
{
auto
&
xs
=
*
desc_
.
mutable_attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
if
(
it
==
xs
.
end
())
{
auto
*
attr
=
xs
.
Add
();
attr
->
set_name
(
name
);
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
}
it
->
set_type
(
framework
::
proto
::
STRING
);
it
->
set_s
(
v
.
c_str
());
}
}
// namespace pb
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/pb/op_desc.h
浏览文件 @
e7f32773
...
...
@@ -120,28 +120,24 @@ class OpDesc {
if
(
it
==
xs
.
end
())
{
auto
*
attr
=
xs
.
Add
();
attr
->
set_name
(
name
);
it
=
std
::
find
(
xs
.
begin
(),
xs
.
end
(),
name
);
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
}
switch
(
typeid
(
T
).
hash_code
())
{
case
typeid
(
int
).
hash_code
():
it
->
set_type
(
framework
::
proto
::
INT
);
it
->
set_i
(
v
);
break
;
case
typeid
(
float
).
hash_code
():
it
->
set_type
(
framework
::
proto
::
FLOAT
);
it
->
set_f
(
v
);
break
;
case
typeid
(
std
::
string
).
hash_code
():
it
->
set_type
(
framework
::
proto
::
STRING
);
it
->
set_s
(
v
.
c_str
());
break
;
case
typeid
(
std
::
string
).
hash_code
():
it
->
set_type
(
framework
::
proto
::
BOOLEAN
);
it
->
set_b
(
v
);
break
;
default:
LOG
(
FATAL
)
<<
"unsupport attr type"
;
size_t
hash
=
typeid
(
T
).
hash_code
();
if
(
hash
==
typeid
(
int
).
hash_code
())
{
it
->
set_type
(
framework
::
proto
::
INT
);
it
->
set_i
(
v
);
}
else
if
(
hash
==
typeid
(
float
).
hash_code
())
{
it
->
set_type
(
framework
::
proto
::
FLOAT
);
it
->
set_f
(
v
);
}
else
if
(
hash
==
typeid
(
bool
).
hash_code
())
{
it
->
set_type
(
framework
::
proto
::
BOOLEAN
);
it
->
set_b
(
v
);
}
else
{
LOG
(
FATAL
)
<<
"unsupport attr type"
;
}
}
...
...
@@ -229,6 +225,10 @@ class OpDesc {
framework
::
proto
::
OpDesc
desc_
;
};
template
<
>
void
OpDesc
::
SetAttr
<
std
::
string
>
(
const
std
::
string
&
name
,
const
std
::
string
&
v
);
}
// namespace pb
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/utils/all.h
浏览文件 @
e7f32773
...
...
@@ -17,5 +17,6 @@
#include "paddle/fluid/lite/utils/check.h"
#include "paddle/fluid/lite/utils/factory.h"
#include "paddle/fluid/lite/utils/hash.h"
#include "paddle/fluid/lite/utils/io.h"
#include "paddle/fluid/lite/utils/macros.h"
#include "paddle/fluid/lite/utils/varient.h"
paddle/fluid/lite/utils/io.h
0 → 100644
浏览文件 @
e7f32773
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fstream>
#include <string>
namespace
paddle
{
namespace
lite
{
static
bool
IsFileExists
(
const
std
::
string
&
path
)
{
std
::
ifstream
file
(
path
);
bool
res
=
file
.
is_open
();
if
(
res
)
{
file
.
close
();
}
return
res
;
}
}
// namespace lite
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录