Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
70540d1b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
70540d1b
编写于
4月 29, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add a new lightweight OpDesc compatible with the original framework::OpDesc
to support mobile
上级
6e19097b
变更
32
隐藏空白更改
内联
并排
Showing
32 changed file
with
929 addition
and
69 deletion
+929
-69
CMakeLists.txt
CMakeLists.txt
+2
-1
cmake/configure.cmake
cmake/configure.cmake
+4
-0
paddle/fluid/lite/api/CMakeLists.txt
paddle/fluid/lite/api/CMakeLists.txt
+3
-3
paddle/fluid/lite/api/cxx_api.h
paddle/fluid/lite/api/cxx_api.h
+1
-2
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+2
-2
paddle/fluid/lite/core/mir/io_complement_pass.cc
paddle/fluid/lite/core/mir/io_complement_pass.cc
+2
-16
paddle/fluid/lite/core/op_executor.h
paddle/fluid/lite/core/op_executor.h
+2
-0
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+1
-1
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+3
-5
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+15
-13
paddle/fluid/lite/model_parser/CMakeLists.txt
paddle/fluid/lite/model_parser/CMakeLists.txt
+8
-0
paddle/fluid/lite/model_parser/compatible_pb.cc
paddle/fluid/lite/model_parser/compatible_pb.cc
+15
-0
paddle/fluid/lite/model_parser/compatible_pb.h
paddle/fluid/lite/model_parser/compatible_pb.h
+45
-0
paddle/fluid/lite/model_parser/pb/CMakeLists.txt
paddle/fluid/lite/model_parser/pb/CMakeLists.txt
+2
-0
paddle/fluid/lite/model_parser/pb/block_desc.cc
paddle/fluid/lite/model_parser/pb/block_desc.cc
+13
-0
paddle/fluid/lite/model_parser/pb/block_desc.h
paddle/fluid/lite/model_parser/pb/block_desc.h
+123
-0
paddle/fluid/lite/model_parser/pb/op_desc.cc
paddle/fluid/lite/model_parser/pb/op_desc.cc
+15
-0
paddle/fluid/lite/model_parser/pb/op_desc.h
paddle/fluid/lite/model_parser/pb/op_desc.h
+234
-0
paddle/fluid/lite/model_parser/pb/program_desc.cc
paddle/fluid/lite/model_parser/pb/program_desc.cc
+13
-0
paddle/fluid/lite/model_parser/pb/program_desc.h
paddle/fluid/lite/model_parser/pb/program_desc.h
+13
-0
paddle/fluid/lite/model_parser/pb/var_desc.cc
paddle/fluid/lite/model_parser/pb/var_desc.cc
+271
-0
paddle/fluid/lite/model_parser/pb/var_desc.h
paddle/fluid/lite/model_parser/pb/var_desc.h
+123
-0
paddle/fluid/lite/operators/fc_op.h
paddle/fluid/lite/operators/fc_op.h
+2
-4
paddle/fluid/lite/operators/feed_op.cc
paddle/fluid/lite/operators/feed_op.cc
+2
-3
paddle/fluid/lite/operators/fetch_op.cc
paddle/fluid/lite/operators/fetch_op.cc
+2
-3
paddle/fluid/lite/operators/io_copy_op.cc
paddle/fluid/lite/operators/io_copy_op.cc
+1
-2
paddle/fluid/lite/operators/io_copy_op.h
paddle/fluid/lite/operators/io_copy_op.h
+1
-1
paddle/fluid/lite/operators/mul_op.h
paddle/fluid/lite/operators/mul_op.h
+3
-4
paddle/fluid/lite/operators/relu_op.cc
paddle/fluid/lite/operators/relu_op.cc
+1
-1
paddle/fluid/lite/operators/relu_op.h
paddle/fluid/lite/operators/relu_op.h
+1
-1
paddle/fluid/lite/operators/scale_op.cc
paddle/fluid/lite/operators/scale_op.cc
+4
-6
paddle/fluid/lite/utils/varient.h
paddle/fluid/lite/utils/varient.h
+2
-1
未找到文件。
CMakeLists.txt
浏览文件 @
70540d1b
...
...
@@ -186,7 +186,8 @@ endif()
# for lite
option
(
LITE_WITH_CUDA
"Enable CUDA in lite mode"
ON
)
option
(
LITE_WITH_X86
"Enable X86 in lite mode"
ON
)
option
(
LITE_WITH_X86
"Enable X86 in lite mode"
ON
)
option
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"Enable light-weight framework"
ON
)
include
(
external/threadpool
)
include
(
flags
)
# set paddle compile flags
...
...
cmake/configure.cmake
浏览文件 @
70540d1b
...
...
@@ -171,3 +171,7 @@ endif()
if
(
LITE_WITH_X86
)
add_definitions
(
"-DLITE_WITH_X86"
)
endif
()
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
add_definitions
(
"-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK"
)
endif
()
paddle/fluid/lite/api/CMakeLists.txt
浏览文件 @
70540d1b
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host
)
cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite
)
if
(
LITE_WITH_CUDA
)
cc_library
(
cxx_api_lite_cuda SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host target_wrapper_cuda kernels_cuda
)
nv_test
(
test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda model_parser_lite
)
else
()
cc_library
(
cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host
)
cc_test
(
test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite target_wrapper_host host_kernels
)
endif
()
paddle/fluid/lite/api/cxx_api.h
浏览文件 @
70540d1b
...
...
@@ -33,9 +33,8 @@ class Predictor {
const
std
::
vector
<
Place
>&
valid_places
)
{
framework
::
proto
::
ProgramDesc
prog
;
LoadModel
(
model_path
,
scope_
.
get
(),
&
prog
);
framework
::
ProgramDesc
prog_desc
(
prog
);
Program
program
(
prog
_desc
,
scope_
,
valid_places
);
Program
program
(
prog
,
scope_
,
valid_places
);
Optimizer
optimizer
;
optimizer
.
KernelPickPreferPlace
(
prefer_place
);
...
...
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
70540d1b
...
...
@@ -5,10 +5,10 @@ cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite)
cc_library
(
variable_lite SRCS variable.cc
)
cc_library
(
op_registry_lite SRCS op_registry.cc
)
cc_library
(
scope_lite SRCS scope.cc
)
cc_library
(
op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite
)
cc_library
(
op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite
compatible_pb_lite
)
cc_library
(
op_executor_lite SRCS op_executor.cc DEPS scope_lite tensor_lite op_lite op_registry_lite
#TODO(Superjomn) remove these dependencies from original framework
proto_desc
)
)
cc_library
(
kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite
)
cc_library
(
types_lite SRCS types.cc
)
cc_library
(
type_system SRCS type_system.cc DEPS tensor_lite
)
...
...
paddle/fluid/lite/core/mir/io_complement_pass.cc
浏览文件 @
70540d1b
...
...
@@ -65,19 +65,6 @@ void IoComplementPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
}
}
void
UpdateOpdescInputName
(
framework
::
OpDesc
*
desc
,
const
std
::
string
&
old_arg_name
,
const
std
::
string
&
new_arg_name
)
{
for
(
auto
&
item
:
*
desc
->
Proto
()
->
mutable_inputs
())
{
for
(
int
i
=
0
;
i
<
item
.
mutable_arguments
()
->
size
();
i
++
)
{
auto
*
arg
=
item
.
mutable_arguments
(
i
);
if
(
*
arg
==
old_arg_name
)
{
*
arg
=
new_arg_name
;
}
}
}
}
void
IoComplementPass
::
AddIoCopyInst
(
const
Type
&
from
,
const
Type
&
to
,
const
std
::
string
&
var
,
SSAGraph
*
graph
,
Node
*
inst_node
,
...
...
@@ -99,11 +86,10 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
inst_node
->
AsInstruct
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
// Create IoCopy Instruction.
framework
::
OpDesc
op_desc
;
lite
::
OpDesc
op_desc
;
op_desc
.
SetType
(
"io_copy"
);
op_desc
.
SetInput
(
"Input"
,
{
var
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
op_desc
.
Flush
();
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
AsInstruct
().
op
->
scope
());
auto
kernels
=
io_copy_op
->
CreateKernels
(
valid_places
);
...
...
@@ -126,7 +112,7 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
auto
desc_dummy
=
inst_node
->
AsInstruct
().
op
->
op_info
()
->
desc
();
UpdateInputTo
(
&
desc_dummy
,
var
,
io_copy_output_name
);
framework
::
OpDesc
desc_fake
(
desc_dummy
,
nullptr
);
lite
::
OpDesc
desc_fake
(
desc_dummy
);
inst_node
->
AsInstruct
().
op
->
Attach
(
desc_fake
,
inst_node
->
AsInstruct
().
op
->
scope
());
...
...
paddle/fluid/lite/core/op_executor.h
浏览文件 @
70540d1b
...
...
@@ -23,6 +23,7 @@
namespace
paddle
{
namespace
lite
{
/*
// The Executor is used to run the operators.
class Executor {
public:
...
...
@@ -63,6 +64,7 @@ class RuntimeExecutor {
private:
RuntimeProgram* program_{};
};
*/
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
70540d1b
...
...
@@ -76,7 +76,7 @@ bool OpLite::Run() {
return
true
;
}
bool
OpLite
::
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
OpLite
::
Attach
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
CHECK
(
scope
);
scope_
=
scope
;
op_info_
.
reset
(
new
OpInfo
);
// Force clean the out-of-date infomation.
...
...
paddle/fluid/lite/core/op_lite.h
浏览文件 @
70540d1b
...
...
@@ -19,12 +19,11 @@
#include <map>
#include <memory>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
namespace
paddle
{
namespace
lite
{
...
...
@@ -82,7 +81,7 @@ class OpLite : public Registry {
virtual
bool
Run
();
// Link the external execution environ to internal context.
bool
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
);
bool
Attach
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
);
const
OpInfo
*
op_info
()
const
{
return
op_info_
.
get
();
}
OpInfo
*
mutable_op_info
()
{
return
op_info_
.
get
();
}
...
...
@@ -109,8 +108,7 @@ class OpLite : public Registry {
protected:
// Attach it with the runtime environment.
virtual
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
virtual
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
...
...
paddle/fluid/lite/core/program.h
浏览文件 @
70540d1b
...
...
@@ -38,10 +38,10 @@ struct Program {
std
::
vector
<
Place
>
valid_places
;
// Runtime scope.
lite
::
Scope
*
exec_scope
{};
const
framework
::
ProgramDesc
desc
;
const
framework
::
proto
::
ProgramDesc
desc
;
explicit
Program
(
const
std
::
shared_ptr
<
Scope
>&
root
)
{
scope
=
root
;
}
Program
(
const
framework
::
ProgramDesc
&
desc
,
Program
(
const
framework
::
proto
::
ProgramDesc
&
desc
,
const
std
::
shared_ptr
<
Scope
>&
root
,
const
std
::
vector
<
Place
>&
valid_places
)
:
scope
(
root
),
valid_places
(
valid_places
),
desc
(
desc
)
{
...
...
@@ -56,24 +56,25 @@ struct Program {
private:
// Build from a program and scope.
void
Build
(
const
framework
::
ProgramDesc
&
program
,
void
Build
(
const
framework
::
proto
::
ProgramDesc
&
program
,
const
std
::
vector
<
Place
>&
valid_places
)
{
CHECK
(
ops
.
empty
())
<<
"Executor duplicate Build found"
;
// Create operators.
for
(
auto
*
op_desc
:
program
.
Block
(
0
).
AllOps
())
{
auto
op_type
=
op_desc
->
Type
();
for
(
const
auto
&
proto_op_desc
:
program
.
blocks
(
0
).
ops
())
{
lite
::
OpDesc
op_desc
(
proto_op_desc
);
auto
op_type
=
op_desc
.
Type
();
// if (op_type == "feed" || op_type == "fetch") continue;
VLOG
(
4
)
<<
"create Op ["
<<
op_type
<<
"]"
;
ops
.
emplace_back
(
LiteOpRegistry
::
Global
().
Create
(
op_type
));
// pick initial kernel
ops
.
back
()
->
PickKernel
(
valid_places
);
ops
.
back
()
->
Attach
(
*
op_desc
,
exec_scope
);
ops
.
back
()
->
Attach
(
op_desc
,
exec_scope
);
}
}
// Create temporary variables.
void
PrepareWorkspace
(
const
framework
::
ProgramDesc
&
program
)
{
void
PrepareWorkspace
(
const
framework
::
proto
::
ProgramDesc
&
program
)
{
CHECK
(
!
exec_scope
)
<<
"Duplicate PrepareWorkspace found"
;
exec_scope
=
&
scope
->
NewScope
();
// Create Feed and Fetch var.
...
...
@@ -82,13 +83,14 @@ struct Program {
tmp_vars
.
push_back
(
"feed"
);
tmp_vars
.
push_back
(
"fetch"
);
for
(
auto
var_desc
:
program
.
Block
(
0
).
AllVars
())
{
if
(
!
var_desc
->
Persistable
())
{
tmp_vars
.
push_back
(
var_desc
->
Name
());
exec_scope
->
Var
(
var_desc
->
Name
());
for
(
auto
proto_var_desc
:
program
.
blocks
(
0
).
vars
())
{
lite
::
VarDesc
var_desc
(
proto_var_desc
);
if
(
!
var_desc
.
Persistable
())
{
tmp_vars
.
push_back
(
var_desc
.
Name
());
exec_scope
->
Var
(
var_desc
.
Name
());
}
else
{
if
(
var_desc
->
Name
()
==
"feed"
||
var_desc
->
Name
()
==
"fetch"
)
continue
;
weights
.
push_back
(
var_desc
->
Name
());
if
(
var_desc
.
Name
()
==
"feed"
||
var_desc
.
Name
()
==
"fetch"
)
continue
;
weights
.
push_back
(
var_desc
.
Name
());
}
}
}
...
...
paddle/fluid/lite/model_parser/CMakeLists.txt
浏览文件 @
70540d1b
cc_library
(
model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite tensor_lite scope_lite
)
cc_library
(
runtime_lite SRCS runtime.cc
)
cc_test
(
test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite
)
if
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite var_desc_lite
)
else
()
cc_library
(
compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto
)
endif
(
LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
)
add_subdirectory
(
pb
)
paddle/fluid/lite/model_parser/compatible_pb.cc
0 → 100644
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
paddle/fluid/lite/model_parser/compatible_pb.h
0 → 100644
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
/*
* This file implements the interface to manipute the protobuf message. We use
* macros to make a compatible interface with the framework::XXDesc and
* lite::pb::XXDesc.
*/
#include "paddle/fluid/framework/framework.pb.h"
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
#include "paddle/fluid/lite/model_parser/pb/var_desc.h"
#else
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
namespace
paddle
{
namespace
lite
{
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
using
OpDesc
=
lite
::
pb
::
OpDesc
;
using
VarDesc
=
lite
::
pb
::
VarDesc
;
#else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
using
Attribute
=
framework
::
Attribute
;
using
OpDesc
=
framework
::
OpDesc
;
using
VarDesc
=
framework
::
VarDesc
;
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/pb/CMakeLists.txt
0 → 100644
浏览文件 @
70540d1b
cc_library
(
var_desc_lite SRCS var_desc.cc DEPS framework_proto
)
cc_library
(
op_desc_lite SRCS op_desc.cc DEPS framework_proto
)
paddle/fluid/lite/model_parser/pb/block_desc.cc
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
paddle/fluid/lite/model_parser/pb/block_desc.h
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <deque>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/macros.h"
namespace
paddle
{
namespace
lite
{
class
ProgramDesc
;
// Each Protobuf Message, we provide a XXXBind class. In that class, we optimize
// read/write speed. Only when we want the protobuf message, the local changes
// will be synchronized (by `Sync` method).
class
BlockDesc
{
public:
BlockDesc
(
ProgramDesc
*
prog
,
proto
::
BlockDesc
*
desc
);
BlockDesc
(
const
BlockDesc
&
other
,
proto
::
BlockDesc
*
desc
,
ProgramDesc
*
prog
);
int32_t
ID
()
const
{
return
desc_
->
idx
();
}
int32_t
Parent
()
const
{
return
desc_
->
parent_idx
();
}
int32_t
ForwardBlockID
()
const
{
return
desc_
->
forward_block_idx
();
}
VarDesc
*
Var
(
const
std
::
string
&
name_bytes
);
VarDesc
*
FindVar
(
const
std
::
string
&
name_bytes
)
const
;
bool
HasVar
(
const
std
::
string
&
var_name
)
const
;
VarDesc
*
RenameVar
(
const
std
::
string
&
old_name
,
const
std
::
string
&
new_name
);
VarDesc
*
FindVarRecursive
(
const
std
::
string
&
name_bytes
)
const
;
VarDesc
&
FindRecursiveOrCreateVar
(
const
std
::
string
&
name_bytes
);
bool
HasVarRecursive
(
const
std
::
string
&
var_name
)
const
;
std
::
set
<
std
::
string
>
LocalVarNames
()
const
{
std
::
set
<
std
::
string
>
var_names
;
for
(
auto
&
var
:
vars_
)
{
var_names
.
insert
(
var
.
first
);
}
return
var_names
;
}
std
::
vector
<
VarDesc
*>
AllVars
()
const
;
BlockDesc
*
ParentBlock
()
const
;
BlockDesc
*
ForwardBlock
()
const
;
void
SetForwardBlockID
(
int32_t
forward_block_id
);
OpDesc
*
AppendOp
();
void
AppendAllocatedOp
(
std
::
unique_ptr
<
OpDesc
>
&&
op_desc
);
OpDesc
*
PrependOp
();
void
PrependAllocatedOp
(
std
::
unique_ptr
<
OpDesc
>
&&
op_desc
);
OpDesc
*
InsertOp
(
size_t
index
);
/*
* Only remove op itself,
* do nothing to its input and output variables
*/
void
RemoveOp
(
size_t
s
,
size_t
e
);
void
RemoveOpInternal
(
const
OpDesc
*
op_desc
);
void
RemoveVar
(
const
std
::
string
&
name
)
{
vars_
.
erase
(
name
);
}
std
::
vector
<
OpDesc
*>
AllOps
()
const
;
size_t
OpSize
()
const
{
return
ops_
.
size
();
}
OpDesc
*
Op
(
int
idx
)
const
{
return
ops_
.
at
(
idx
).
get
();
}
void
Flush
();
proto
::
BlockDesc
*
Proto
();
ProgramDesc
*
Program
()
const
{
return
this
->
prog_
;
}
private:
ProgramDesc
*
prog_
;
// not_own
proto
::
BlockDesc
*
desc_
;
// not_own
bool
need_update_
;
std
::
deque
<
std
::
unique_ptr
<
OpDesc
>>
ops_
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
VarDesc
>>
vars_
;
DISABLE_COPY_AND_ASSIGN
(
BlockDesc
);
};
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/pb/op_desc.cc
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/pb/op_desc.h"
paddle/fluid/lite/model_parser/pb/op_desc.h
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
/*
* This file implements a light-weight OpDesc like the framework::OpDesc. We
* delete the unnecessary methods, and remove the underlying dependencies, such
* as framework::Operator and boost::varient to make it runnable in mobile.
*/
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/lite/utils/all.h"
namespace
paddle
{
namespace
lite
{
namespace
pb
{
using
Attribute
=
variant
<
int
,
float
,
bool
,
std
::
vector
<
std
::
string
>>
;
using
VariableNameMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
/*
* The lite::OpDesc, an light-weight implementation of wrapper of proto::OpDesc.
* Unlike the original one in framework::OpDesc, we remove the local members
* except the desc_, to avoid the inconsistent state, which is normal in the
* original interface and results in bugs.
*/
class
OpDesc
{
public:
OpDesc
()
{}
OpDesc
(
const
framework
::
proto
::
OpDesc
&
desc
)
:
desc_
(
desc
)
{}
void
CopyFrom
(
const
OpDesc
&
op_desc
)
{
desc_
=
op_desc
.
ReadonlyProto
();
}
framework
::
proto
::
OpDesc
*
Proto
()
{
return
&
desc_
;
}
const
framework
::
proto
::
OpDesc
&
ReadonlyProto
()
const
{
return
desc_
;
}
std
::
string
Type
()
const
{
return
desc_
.
type
();
}
void
SetType
(
const
std
::
string
&
type
)
{
desc_
.
set_type
(
type
);
}
// Get the arguments of parameter called `param`
std
::
vector
<
std
::
string
>
Input
(
const
std
::
string
&
param
)
const
{
return
GetArguments
(
desc_
.
inputs
(),
param
);
}
std
::
vector
<
std
::
string
>
InputArgumentNames
()
const
{
return
GetArgumentNames
(
desc_
.
inputs
());
}
void
SetInput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>
&
args
)
{
SetArgument
(
desc_
.
mutable_inputs
(),
param
,
args
);
}
std
::
vector
<
std
::
string
>
Output
(
const
std
::
string
&
param
)
const
{
return
GetArguments
(
desc_
.
outputs
(),
param
);
}
std
::
vector
<
std
::
string
>
OutputArgumentNames
()
const
{
return
GetArgumentNames
(
desc_
.
outputs
());
}
void
SetOutput
(
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>
&
args
)
{
SetArgument
(
desc_
.
mutable_outputs
(),
param
,
args
);
}
bool
HasAttr
(
const
std
::
string
&
name
)
const
{
const
auto
&
xs
=
desc_
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
return
it
!=
xs
.
end
();
}
framework
::
proto
::
AttrType
GetAttrType
(
const
std
::
string
&
name
)
const
{
const
auto
&
xs
=
desc_
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
CHECK
(
it
!=
xs
.
end
());
return
it
->
type
();
}
std
::
vector
<
std
::
string
>
AttrNames
()
const
{
std
::
vector
<
std
::
string
>
res
;
const
auto
&
xs
=
desc_
.
attrs
();
std
::
transform
(
xs
.
begin
(),
xs
.
end
(),
std
::
back_inserter
(
res
),
[](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
();
});
return
res
;
}
template
<
typename
T
>
void
SetAttr
(
const
std
::
string
&
name
,
const
T
&
v
)
{
auto
&
xs
=
*
desc_
.
mutable_attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
if
(
it
==
xs
.
end
())
{
auto
*
attr
=
xs
.
Add
();
attr
->
set_name
(
name
);
it
=
std
::
find
(
xs
.
begin
(),
xs
.
end
(),
name
);
}
switch
(
typeid
(
T
).
hash_code
())
{
case
typeid
(
int
).
hash_code
():
it
->
set_type
(
framework
::
proto
::
INT
);
it
->
set_i
(
v
);
break
;
case
typeid
(
float
).
hash_code
():
it
->
set_type
(
framework
::
proto
::
FLOAT
);
it
->
set_f
(
v
);
break
;
case
typeid
(
std
::
string
).
hash_code
():
it
->
set_type
(
framework
::
proto
::
STRING
);
it
->
set_s
(
v
.
c_str
());
break
;
case
typeid
(
std
::
string
).
hash_code
():
it
->
set_type
(
framework
::
proto
::
BOOLEAN
);
it
->
set_b
(
v
);
break
;
default:
LOG
(
FATAL
)
<<
"unsupport attr type"
;
}
}
Attribute
GetAttr
(
const
std
::
string
&
name
)
const
{
auto
&
xs
=
desc_
.
attrs
();
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Attr
&
x
)
{
return
x
.
name
()
==
name
;
});
Attribute
res
;
CHECK
(
it
!=
xs
.
end
());
switch
(
it
->
type
())
{
case
framework
::
proto
::
INT
:
res
.
set
<
int
>
(
it
->
i
());
break
;
case
framework
::
proto
::
FLOAT
:
res
.
set
<
float
>
(
it
->
f
());
break
;
case
framework
::
proto
::
STRING
:
res
.
set
<
std
::
string
>
(
it
->
s
());
break
;
case
framework
::
proto
::
BOOLEAN
:
res
.
set
<
bool
>
(
it
->
b
());
break
;
default:
LOG
(
FATAL
)
<<
"unsupported attr type"
;
}
return
res
;
}
private:
std
::
vector
<
std
::
string
>
GetArguments
(
const
google
::
protobuf
::
RepeatedPtrField
<
framework
::
proto
::
OpDesc_Var
>
&
xs
,
const
std
::
string
&
param
)
const
{
std
::
vector
<
std
::
string
>
res
;
auto
it
=
std
::
find_if
(
xs
.
begin
(),
xs
.
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Var
&
it
)
{
return
it
.
parameter
()
==
param
;
});
CHECK
(
it
!=
xs
.
end
());
const
auto
&
ys
=
it
->
arguments
();
std
::
transform
(
ys
.
begin
(),
ys
.
end
(),
std
::
back_inserter
(
res
),
[](
const
std
::
string
&
x
)
{
return
x
;
});
return
res
;
}
void
SetArgument
(
google
::
protobuf
::
RepeatedPtrField
<
framework
::
proto
::
OpDesc_Var
>
*
xs
,
const
std
::
string
&
param
,
const
std
::
vector
<
std
::
string
>
&
args
)
{
auto
it
=
std
::
find_if
(
xs
->
begin
(),
xs
->
end
(),
[
&
](
const
framework
::
proto
::
OpDesc_Var
&
it
)
{
return
it
.
parameter
()
==
param
;
});
if
(
it
==
xs
->
end
())
{
auto
*
new_arg
=
xs
->
Add
();
new_arg
->
set_parameter
(
param
);
for
(
const
auto
&
arg
:
args
)
{
*
new_arg
->
mutable_arguments
()
->
Add
()
=
arg
;
}
}
else
{
it
->
mutable_arguments
()
->
Clear
();
for
(
const
auto
&
arg
:
args
)
{
*
it
->
mutable_arguments
()
->
Add
()
=
arg
;
}
}
}
std
::
vector
<
std
::
string
>
GetArgumentNames
(
const
google
::
protobuf
::
RepeatedPtrField
<
framework
::
proto
::
OpDesc_Var
>
&
xs
)
const
{
std
::
vector
<
std
::
string
>
res
;
std
::
transform
(
xs
.
begin
(),
xs
.
end
(),
std
::
back_inserter
(
res
),
[](
const
framework
::
proto
::
OpDesc_Var
&
x
)
{
return
x
.
parameter
();
});
return
res
;
}
private:
framework
::
proto
::
OpDesc
desc_
;
};
}
// namespace pb
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/pb/program_desc.cc
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
paddle/fluid/lite/model_parser/pb/program_desc.h
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
paddle/fluid/lite/model_parser/pb/var_desc.cc
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/pb/var_desc.h"
namespace
paddle
{
namespace
lite
{
namespace
pb
{
using
namespace
framework
;
proto
::
VarType
::
Type
VarDesc
::
GetType
()
const
{
return
desc_
.
type
().
type
();
}
void
VarDesc
::
SetType
(
proto
::
VarType
::
Type
type
)
{
desc_
.
mutable_type
()
->
set_type
(
type
);
}
void
VarDesc
::
SetShape
(
const
std
::
vector
<
int64_t
>
&
dims
)
{
VectorToRepeated
(
dims
,
mutable_tensor_desc
()
->
mutable_dims
());
}
void
VarDesc
::
SetTensorDescNum
(
size_t
num
)
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
READER
:
{
auto
*
lod_tensors_ptr
=
desc_
.
mutable_type
()
->
mutable_reader
()
->
mutable_lod_tensor
();
lod_tensors_ptr
->
Clear
();
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
lod_tensors_ptr
->
Add
();
}
return
;
}
break
;
default:
LOG
(
FATAL
)
<<
"Setting 'sub_tensor_number' is not supported by the type "
"of var %s."
<<
this
->
Name
();
}
}
size_t
VarDesc
::
GetTensorDescNum
()
const
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
READER
:
return
desc_
.
type
().
reader
().
lod_tensor_size
();
break
;
default:
LOG
(
FATAL
)
<<
"Getting 'sub_tensor_number' is not supported by the type "
"of var %s."
<<
this
->
Name
();
}
}
void
VarDesc
::
SetShapes
(
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
multiple_dims
)
{
if
(
multiple_dims
.
size
()
!=
GetTensorDescNum
())
{
VLOG
(
3
)
<<
"WARNING: The number of given shapes("
<<
multiple_dims
.
size
()
<<
") doesn't match the existing tensor number("
<<
GetTensorDescNum
()
<<
"). The Reader is going to be reinitialized."
;
SetTensorDescNum
(
multiple_dims
.
size
());
}
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
tensors
=
mutable_tensor_descs
();
for
(
size_t
i
=
0
;
i
<
multiple_dims
.
size
();
++
i
)
{
VectorToRepeated
(
multiple_dims
[
i
],
tensors
[
i
]
->
mutable_dims
());
}
}
std
::
vector
<
int64_t
>
VarDesc
::
GetShape
()
const
{
return
RepeatedToVector
(
tensor_desc
().
dims
());
}
std
::
vector
<
std
::
vector
<
int64_t
>>
VarDesc
::
GetShapes
()
const
{
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
descs
=
tensor_descs
();
std
::
vector
<
std
::
vector
<
int64_t
>>
res
;
res
.
reserve
(
descs
.
size
());
for
(
const
auto
&
tensor_desc
:
descs
)
{
res
.
push_back
(
RepeatedToVector
(
tensor_desc
.
dims
()));
}
return
res
;
}
void
VarDesc
::
SetDataType
(
proto
::
VarType
::
Type
data_type
)
{
mutable_tensor_desc
()
->
set_data_type
(
data_type
);
}
void
VarDesc
::
SetDataTypes
(
const
std
::
vector
<
proto
::
VarType
::
Type
>
&
multiple_data_type
)
{
if
(
multiple_data_type
.
size
()
!=
GetTensorDescNum
())
{
VLOG
(
3
)
<<
"WARNING: The number of given data types("
<<
multiple_data_type
.
size
()
<<
") doesn't match the existing tensor number("
<<
GetTensorDescNum
()
<<
"). The Reader is going to be reinitialized."
;
SetTensorDescNum
(
multiple_data_type
.
size
());
}
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
tensor_descs
=
mutable_tensor_descs
();
for
(
size_t
i
=
0
;
i
<
multiple_data_type
.
size
();
++
i
)
{
tensor_descs
[
i
]
->
set_data_type
(
multiple_data_type
[
i
]);
}
}
proto
::
VarType
::
Type
VarDesc
::
GetDataType
()
const
{
return
tensor_desc
().
data_type
();
}
std
::
vector
<
proto
::
VarType
::
Type
>
VarDesc
::
GetDataTypes
()
const
{
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
descs
=
tensor_descs
();
std
::
vector
<
proto
::
VarType
::
Type
>
res
;
res
.
reserve
(
descs
.
size
());
for
(
const
auto
&
tensor_desc
:
descs
)
{
res
.
push_back
(
tensor_desc
.
data_type
());
}
return
res
;
}
void
VarDesc
::
SetLoDLevel
(
int32_t
lod_level
)
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
LOD_TENSOR
:
desc_
.
mutable_type
()
->
mutable_lod_tensor
()
->
set_lod_level
(
lod_level
);
break
;
case
proto
::
VarType
::
LOD_TENSOR_ARRAY
:
desc_
.
mutable_type
()
->
mutable_tensor_array
()
->
set_lod_level
(
lod_level
);
break
;
default:
LOG
(
FATAL
)
<<
"Setting 'lod_level' is not supported by the type of var %s."
<<
this
->
Name
();
}
}
void
VarDesc
::
SetLoDLevels
(
const
std
::
vector
<
int32_t
>
&
multiple_lod_level
)
{
if
(
multiple_lod_level
.
size
()
!=
GetTensorDescNum
())
{
VLOG
(
3
)
<<
"WARNING: The number of given lod_levels("
<<
multiple_lod_level
.
size
()
<<
") doesn't match the existing tensor number("
<<
GetTensorDescNum
()
<<
"). The Reader is going to be reinitialized."
;
SetTensorDescNum
(
multiple_lod_level
.
size
());
}
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
READER
:
{
size_t
i
=
0
;
for
(
auto
&
lod_tensor
:
*
desc_
.
mutable_type
()
->
mutable_reader
()
->
mutable_lod_tensor
())
{
lod_tensor
.
set_lod_level
(
multiple_lod_level
[
i
++
]);
}
}
break
;
default:
LOG
(
FATAL
)
<<
"Setting 'lod_levels' is not supported by the type of var %s."
<<
this
->
Name
();
}
}
int32_t
VarDesc
::
GetLoDLevel
()
const
{
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
LOD_TENSOR
:
return
desc_
.
type
().
lod_tensor
().
lod_level
();
case
proto
::
VarType
::
LOD_TENSOR_ARRAY
:
return
desc_
.
type
().
tensor_array
().
lod_level
();
default:
LOG
(
FATAL
)
<<
"Getting 'lod_level' is not supported by the type of var %s."
<<
this
->
Name
();
}
}
std
::
vector
<
int32_t
>
VarDesc
::
GetLoDLevels
()
const
{
std
::
vector
<
int32_t
>
res
;
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
READER
:
res
.
reserve
(
desc_
.
type
().
reader
().
lod_tensor_size
());
for
(
auto
&
lod_tensor
:
desc_
.
type
().
reader
().
lod_tensor
())
{
res
.
push_back
(
lod_tensor
.
lod_level
());
}
return
res
;
break
;
default:
LOG
(
FATAL
)
<<
"Getting 'lod_levels' is not supported by the type of var %s."
<<
this
->
Name
();
}
}
const
proto
::
VarType
::
TensorDesc
&
VarDesc
::
tensor_desc
()
const
{
CHECK
(
desc_
.
has_type
())
<<
"The var's type hasn't been set."
;
CHECK
(
desc_
.
type
().
has_type
())
<<
"The var type hasn't been set."
;
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
SELECTED_ROWS
:
return
desc_
.
type
().
selected_rows
();
case
proto
::
VarType
::
LOD_TENSOR
:
return
desc_
.
type
().
lod_tensor
().
tensor
();
case
proto
::
VarType
::
LOD_TENSOR_ARRAY
:
return
desc_
.
type
().
tensor_array
().
tensor
();
default:
LOG
(
FATAL
)
<<
"Getting 'tensor_desc' is not supported by the type of var %s."
<<
this
->
Name
();
}
}
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
VarDesc
::
tensor_descs
()
const
{
CHECK
(
desc_
.
has_type
())
<<
"The var type hasn't been set."
;
std
::
vector
<
proto
::
VarType
::
TensorDesc
>
res
;
res
.
reserve
(
GetTensorDescNum
());
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
READER
:
for
(
const
auto
&
lod_tensor
:
desc_
.
type
().
reader
().
lod_tensor
())
{
res
.
push_back
(
lod_tensor
.
tensor
());
}
return
res
;
default:
LOG
(
FATAL
)
<<
"Getting 'tensor_descs' is not supported by the type of var "
"%s."
<<
this
->
Name
();
}
}
proto
::
VarType
::
TensorDesc
*
VarDesc
::
mutable_tensor_desc
()
{
CHECK
(
desc_
.
has_type
())
<<
"The var type hasn't been set."
;
CHECK
(
desc_
.
type
().
has_type
())
<<
"The var type hasn't been set."
;
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
SELECTED_ROWS
:
return
desc_
.
mutable_type
()
->
mutable_selected_rows
();
case
proto
::
VarType
::
LOD_TENSOR
:
return
desc_
.
mutable_type
()
->
mutable_lod_tensor
()
->
mutable_tensor
();
case
proto
::
VarType
::
LOD_TENSOR_ARRAY
:
return
desc_
.
mutable_type
()
->
mutable_tensor_array
()
->
mutable_tensor
();
default:
LOG
(
FATAL
)
<<
"Getting 'mutable_tensor_desc' is not supported by the "
"type of var "
"%s."
<<
this
->
Name
();
}
}
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
VarDesc
::
mutable_tensor_descs
()
{
CHECK
(
desc_
.
has_type
())
<<
"The var type hasn't been set."
;
CHECK
(
desc_
.
type
().
has_type
())
<<
"The var type hasn't been set."
;
std
::
vector
<
proto
::
VarType
::
TensorDesc
*>
res
;
res
.
reserve
(
GetTensorDescNum
());
switch
(
desc_
.
type
().
type
())
{
case
proto
::
VarType
::
READER
:
for
(
auto
&
lod_tensor
:
*
desc_
.
mutable_type
()
->
mutable_reader
()
->
mutable_lod_tensor
())
{
res
.
push_back
(
lod_tensor
.
mutable_tensor
());
}
return
res
;
default:
LOG
(
FATAL
)
<<
"Getting 'tensor_descs' is not supported by the type of var "
"%s."
<<
this
->
Name
();
}
}
}
// namespace pb
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/model_parser/pb/var_desc.h
浏览文件 @
70540d1b
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
namespace
paddle
{
namespace
lite
{
namespace
pb
{
// convert between std::vector and protobuf repeated.
template
<
typename
T
>
inline
std
::
vector
<
T
>
RepeatedToVector
(
const
google
::
protobuf
::
RepeatedField
<
T
>
&
repeated_field
)
{
std
::
vector
<
T
>
ret
;
ret
.
reserve
(
repeated_field
.
size
());
std
::
copy
(
repeated_field
.
begin
(),
repeated_field
.
end
(),
std
::
back_inserter
(
ret
));
return
ret
;
}
template
<
typename
T
,
typename
RepeatedField
>
inline
void
VectorToRepeated
(
const
std
::
vector
<
T
>
&
vec
,
RepeatedField
*
repeated_field
)
{
repeated_field
->
Clear
();
repeated_field
->
Reserve
(
vec
.
size
());
for
(
const
auto
&
elem
:
vec
)
{
*
repeated_field
->
Add
()
=
elem
;
}
}
// Specialize vector<bool>.
template
<
typename
RepeatedField
>
inline
void
VectorToRepeated
(
const
std
::
vector
<
bool
>
&
vec
,
RepeatedField
*
repeated_field
)
{
repeated_field
->
Clear
();
repeated_field
->
Reserve
(
vec
.
size
());
for
(
auto
elem
:
vec
)
{
*
repeated_field
->
Add
()
=
elem
;
}
}
class
VarDesc
{
public:
explicit
VarDesc
(
const
std
::
string
&
name
)
{
desc_
.
set_name
(
name
);
// TODO(paddle-dev): Why default to lodtensor.
desc_
.
mutable_type
()
->
set_type
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
explicit
VarDesc
(
const
framework
::
proto
::
VarDesc
&
desc
)
:
desc_
(
desc
)
{}
framework
::
proto
::
VarDesc
*
Proto
()
{
return
&
desc_
;
}
std
::
string
Name
()
const
{
return
desc_
.
name
();
}
void
SetName
(
std
::
string
name
)
{
desc_
.
set_name
(
name
);
}
void
SetTensorDescNum
(
size_t
num
);
size_t
GetTensorDescNum
()
const
;
void
SetShape
(
const
std
::
vector
<
int64_t
>
&
dims
);
void
SetShapes
(
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
multiple_dims
);
std
::
vector
<
int64_t
>
GetShape
()
const
;
std
::
vector
<
std
::
vector
<
int64_t
>>
GetShapes
()
const
;
void
SetDataType
(
framework
::
proto
::
VarType
::
Type
data_type
);
void
SetDataTypes
(
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
&
multiple_data_type
);
framework
::
proto
::
VarType
::
Type
GetDataType
()
const
;
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
GetDataTypes
()
const
;
void
SetLoDLevel
(
int32_t
lod_level
);
void
SetLoDLevels
(
const
std
::
vector
<
int32_t
>
&
multiple_lod_level
);
int32_t
GetLoDLevel
()
const
;
std
::
vector
<
int32_t
>
GetLoDLevels
()
const
;
framework
::
proto
::
VarType
::
Type
GetType
()
const
;
void
SetType
(
framework
::
proto
::
VarType
::
Type
type
);
bool
Persistable
()
const
{
return
desc_
.
persistable
();
}
void
SetPersistable
(
bool
persistable
)
{
desc_
.
set_persistable
(
persistable
);
}
private:
const
framework
::
proto
::
VarType
::
TensorDesc
&
tensor_desc
()
const
;
std
::
vector
<
framework
::
proto
::
VarType
::
TensorDesc
>
tensor_descs
()
const
;
framework
::
proto
::
VarType
::
TensorDesc
*
mutable_tensor_desc
();
std
::
vector
<
framework
::
proto
::
VarType
::
TensorDesc
*>
mutable_tensor_descs
();
framework
::
proto
::
VarDesc
desc_
;
};
}
// namespace pb
}
// namespace framework
}
// namespace paddle
paddle/fluid/lite/operators/fc_op.h
浏览文件 @
70540d1b
...
...
@@ -46,8 +46,7 @@ class FcOpLite : public OpLite {
*/
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"Input"
).
front
();
auto
W
=
op_desc
.
Input
(
"W"
).
front
();
auto
bias
=
op_desc
.
Input
(
"Bias"
).
front
();
...
...
@@ -58,8 +57,7 @@ class FcOpLite : public OpLite {
param_
.
bias
=
scope
->
FindVar
(
bias
)
->
GetMutable
<
Tensor
>
();
CHECK
(
scope
->
FindVar
(
out
));
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
Tensor
>
();
param_
.
in_num_col_dims
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"in_num_col_dims"
));
param_
.
in_num_col_dims
=
op_desc
.
GetAttr
(
"in_num_col_dims"
).
get
<
int
>
();
CHECK
(
kernel_
);
kernel_
->
SetParam
(
param_
);
...
...
paddle/fluid/lite/operators/feed_op.cc
浏览文件 @
70540d1b
...
...
@@ -35,8 +35,7 @@ class FeedOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
feed_var_name
=
opdesc
.
Input
(
"X"
).
front
();
auto
*
feed_var
=
scope
->
FindVar
(
feed_var_name
);
CHECK
(
feed_var
);
...
...
@@ -50,7 +49,7 @@ class FeedOp : public OpLite {
// NOTE need boost here
// TODO(Superjomn) drop the need of framework::op_desc
param_
.
col
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"col"
)
);
param_
.
col
=
opdesc
.
GetAttr
(
"col"
).
get
<
int
>
(
);
return
true
;
}
...
...
paddle/fluid/lite/operators/fetch_op.cc
浏览文件 @
70540d1b
...
...
@@ -33,8 +33,7 @@ class FetchOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
{
auto
_x
=
opdesc
.
Input
(
"X"
).
front
();
auto
*
x
=
scope
->
FindVar
(
_x
);
CHECK
(
x
);
...
...
@@ -44,7 +43,7 @@ class FetchOp : public OpLite {
auto
*
out
=
scope
->
FindVar
(
_out
);
param_
.
fetch_list
=
out
->
GetMutable
<
std
::
vector
<
lite
::
Tensor
>>
();
param_
.
col
=
boost
::
get
<
int
>
(
opdesc
.
GetAttr
(
"col"
)
);
param_
.
col
=
opdesc
.
GetAttr
(
"col"
).
get
<
int
>
(
);
return
true
;
}
...
...
paddle/fluid/lite/operators/io_copy_op.cc
浏览文件 @
70540d1b
...
...
@@ -29,8 +29,7 @@ bool IoCopyOp::InferShape() const {
return
true
;
}
bool
IoCopyOp
::
Run
()
{
return
OpLite
::
Run
();
}
bool
IoCopyOp
::
AttachImpl
(
const
paddle
::
framework
::
OpDesc
&
opdesc
,
paddle
::
lite
::
Scope
*
scope
)
{
bool
IoCopyOp
::
AttachImpl
(
const
OpDesc
&
opdesc
,
paddle
::
lite
::
Scope
*
scope
)
{
auto
x
=
opdesc
.
Input
(
"Input"
).
front
();
auto
out
=
opdesc
.
Output
(
"Out"
).
front
();
param_
.
x
=
GetTensor
(
scope
,
x
);
...
...
paddle/fluid/lite/operators/io_copy_op.h
浏览文件 @
70540d1b
...
...
@@ -31,7 +31,7 @@ class IoCopyOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
private:
operators
::
IoCopyParam
param_
;
...
...
paddle/fluid/lite/operators/mul_op.h
浏览文件 @
70540d1b
...
...
@@ -38,8 +38,7 @@ class MulOpLite : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
W
=
op_desc
.
Input
(
"Y"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
...
...
@@ -48,8 +47,8 @@ class MulOpLite : public OpLite {
param_
.
y
=
scope
->
FindVar
(
W
)
->
GetMutable
<
Tensor
>
();
CHECK
(
scope
->
FindVar
(
out
));
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
Tensor
>
();
param_
.
x_num_col_dims
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"x_num_col_dims"
)
);
param_
.
y_num_col_dims
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"y_num_col_dims"
)
);
param_
.
x_num_col_dims
=
op_desc
.
GetAttr
(
"x_num_col_dims"
).
get
<
int
>
(
);
param_
.
y_num_col_dims
=
op_desc
.
GetAttr
(
"y_num_col_dims"
).
get
<
int
>
(
);
return
true
;
}
...
...
paddle/fluid/lite/operators/relu_op.cc
浏览文件 @
70540d1b
...
...
@@ -31,7 +31,7 @@ bool ReluOp::InferShape() const {
return
true
;
}
bool
ReluOp
::
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
ReluOp
::
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
input
=
const_cast
<
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"Input"
).
front
())
->
Get
<
Tensor
>
());
param_
.
output
=
...
...
paddle/fluid/lite/operators/relu_op.h
浏览文件 @
70540d1b
...
...
@@ -32,7 +32,7 @@ class ReluOp : public OpLite {
bool
InferShape
()
const
override
;
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"tanh"
;
}
...
...
paddle/fluid/lite/operators/scale_op.cc
浏览文件 @
70540d1b
...
...
@@ -46,18 +46,16 @@ class ScaleOp : public OpLite {
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
param_
.
x
=
scope
->
FindVar
(
x
)
->
GetMutable
<
Tensor
>
();
CHECK
(
scope
->
FindVar
(
out
));
param_
.
output
=
scope
->
FindVar
(
out
)
->
GetMutable
<
Tensor
>
();
param_
.
scale
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"scale"
));
param_
.
bias
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"bias"
));
param_
.
bias_after_scale
=
boost
::
get
<
bool
>
(
op_desc
.
GetAttr
(
"bias_after_scale"
));
param_
.
scale
=
op_desc
.
GetAttr
(
"scale"
).
get
<
float
>
();
param_
.
bias
=
op_desc
.
GetAttr
(
"bias"
).
get
<
float
>
();
param_
.
bias_after_scale
=
op_desc
.
GetAttr
(
"bias_after_scale"
).
get
<
bool
>
();
CHECK
(
kernel_
);
kernel_
->
SetParam
(
param_
);
...
...
paddle/fluid/lite/utils/varient.h
浏览文件 @
70540d1b
...
...
@@ -114,7 +114,8 @@ struct variant {
if
(
type_id
==
typeid
(
T
).
hash_code
())
return
*
reinterpret_cast
<
T
*>
(
&
data
);
else
throw
std
::
bad_cast
();
LOG
(
FATAL
)
<<
"unmatched type get, should be "
<<
type_id
<<
" but get "
<<
typeid
(
T
).
name
();
}
~
variant
()
{
helper_t
::
destroy
(
type_id
,
&
data
);
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录