Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
18959c09
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
18959c09
编写于
4月 19, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
make op pointer shared_ptr to support kernel infershape
上级
367a2814
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
30 addition
and
23 deletion
+30
-23
paddle/fluid/lite/api/cxx_api.h
paddle/fluid/lite/api/cxx_api.h
+1
-1
paddle/fluid/lite/core/mir/node.h
paddle/fluid/lite/core/mir/node.h
+4
-0
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+3
-9
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+2
-2
paddle/fluid/lite/core/op_registry.h
paddle/fluid/lite/core/op_registry.h
+4
-2
paddle/fluid/lite/kernels/CMakeLists.txt
paddle/fluid/lite/kernels/CMakeLists.txt
+1
-0
paddle/fluid/lite/kernels/host/CMakeLists.txt
paddle/fluid/lite/kernels/host/CMakeLists.txt
+6
-6
paddle/fluid/lite/kernels/host/mul_compute.cc
paddle/fluid/lite/kernels/host/mul_compute.cc
+6
-0
paddle/fluid/lite/utils/factory.h
paddle/fluid/lite/utils/factory.h
+3
-3
未找到文件。
paddle/fluid/lite/api/cxx_api.h
浏览文件 @
18959c09
...
...
@@ -25,7 +25,7 @@ struct Config {};
class
Predictor
{
public:
void
Build
(
const
std
::
string
&
model_path
,
const
std
::
vector
<
OpLite
::
Place
>&
valid_places
)
{
const
std
::
vector
<
Place
>&
valid_places
)
{
CHECK
(
!
executor_
.
get
())
<<
"duplicate build found"
;
framework
::
proto
::
ProgramDesc
prog
;
LoadModel
(
model_path
,
&
scope_
,
&
prog
);
...
...
paddle/fluid/lite/core/mir/node.h
浏览文件 @
18959c09
...
...
@@ -46,6 +46,8 @@ class Node {
// The kernel instances this Instruct contains.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
;
std
::
shared_ptr
<
OpInfo
>
op_info
;
// TODO(Superjomn) make this a shared_ptr for resource safety.
std
::
shared_ptr
<
OpLite
>
op
;
// we hold op to run InferShape
};
struct
Argument
{
...
...
@@ -64,9 +66,11 @@ class Node {
Instruct
&
AsInstruct
(
const
std
::
string
&
op_type
,
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
,
const
std
::
shared_ptr
<
OpLite
>&
op
,
const
std
::
shared_ptr
<
lite
::
OpInfo
>&
op_info
)
{
auto
&
x
=
AsInstruct
();
x
.
op_type
=
op_type
;
x
.
op
=
op
;
x
.
valid_kernels
=
std
::
move
(
kernels
);
x
.
op_info
=
op_info
;
return
x
;
...
...
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
18959c09
...
...
@@ -34,13 +34,7 @@ namespace mir {
struct
Program
{
std
::
list
<
std
::
string
>
tmp_vars
;
std
::
list
<
std
::
string
>
weights
;
std
::
list
<
std
::
unique_ptr
<
OpLite
>>
ops
;
lite
::
Scope
*
scope
{};
};
// Program of kernel.
struct
KernelProgram
{
std
::
list
<
std
::
unique_ptr
<
KernelBase
>>
instructions
;
std
::
list
<
std
::
shared_ptr
<
OpLite
>>
ops
;
lite
::
Scope
*
scope
{};
};
...
...
@@ -67,7 +61,7 @@ class SSAGraph : GraphBase {
op
->
SetValidPlaces
(
valid_places
);
auto
&
new_node
=
node_storage_
.
back
();
node_storage_
.
back
().
AsInstruct
(
op
->
op_type_
,
op
->
CreateKernels
(
valid_places
),
op
->
op_info
());
op
->
op_type_
,
op
->
CreateKernels
(
valid_places
),
op
,
op
->
op_info
());
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
...
...
@@ -122,7 +116,7 @@ class SSAGraph : GraphBase {
const
std
::
list
<
mir
::
Node
>
&
nodes
()
const
{
return
node_storage_
;
}
std
::
list
<
mir
::
Node
>
&
mutable_nodes
()
{
return
node_storage_
;
}
mir
::
Node
*
RetriveArgument
(
const
std
::
string
&
arg
)
{
mir
::
Node
*
Retri
e
veArgument
(
const
std
::
string
&
arg
)
{
auto
it
=
arguments_
.
find
(
arg
);
if
(
it
!=
arguments_
.
end
())
{
return
it
->
second
;
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
18959c09
...
...
@@ -63,7 +63,7 @@ class VariablePlaceInferencePass : public DebugPass {
// kernel's declaration.
for
(
auto
&
arg_name
:
arg_names
)
{
auto
*
node
=
graph
->
RetriveArgument
(
arg_name
);
auto
*
node
=
graph
->
Retri
e
veArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArgument
();
if
(
arg_node
.
place
.
is_valid
())
continue
;
...
...
@@ -82,7 +82,7 @@ class VariablePlaceInferencePass : public DebugPass {
// kernel's declaration.
for
(
auto
&
arg_name
:
arg_names
)
{
auto
*
node
=
graph
->
RetriveArgument
(
arg_name
);
auto
*
node
=
graph
->
Retri
e
veArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArgument
();
if
(
arg_node
.
place
.
is_valid
())
continue
;
...
...
paddle/fluid/lite/core/op_registry.h
浏览文件 @
18959c09
...
...
@@ -27,7 +27,7 @@ namespace lite {
using
KernelFunc
=
std
::
function
<
void
()
>
;
using
KernelFuncCreator
=
std
::
function
<
std
::
unique_ptr
<
KernelFunc
>
()
>
;
class
LiteOpRegistry
final
:
public
Factory
<
OpLite
>
{
class
LiteOpRegistry
final
:
public
Factory
<
OpLite
,
std
::
shared_ptr
<
OpLite
>
>
{
public:
static
LiteOpRegistry
&
Global
()
{
static
auto
*
x
=
new
LiteOpRegistry
;
...
...
@@ -51,7 +51,9 @@ class OpLiteRegistor : public Registor<OpClass> {
};
template
<
TargetType
Target
,
PrecisionType
Precision
>
using
KernelRegistryForTarget
=
Factory
<
OpKernel
<
Target
,
Precision
>>
;
using
KernelRegistryForTarget
=
Factory
<
OpKernel
<
Target
,
Precision
>
,
std
::
unique_ptr
<
OpKernel
<
Target
,
Precision
>>>
;
class
KernelRegistry
final
{
public:
...
...
paddle/fluid/lite/kernels/CMakeLists.txt
浏览文件 @
18959c09
set
(
lite_kernel_deps type_system kernel_lite op_registry_lite
)
add_subdirectory
(
host
)
add_subdirectory
(
arm
)
add_subdirectory
(
cuda
)
paddle/fluid/lite/kernels/host/CMakeLists.txt
浏览文件 @
18959c09
cc_library
(
fc_compute_host SRCS fc_compute.cc DEPS
tensor_lite
)
cc_library
(
relu_compute_host SRCS relu_compute.cc DEPS
tensor_lite
)
cc_library
(
mul_compute_host SRCS mul_compute.cc DEPS
tensor_lite
)
cc_library
(
scale_compute_host SRCS scale_compute.cc DEPS
tensor_lite
)
cc_library
(
feed_compute_host SRCS feed_compute.cc DEPS
tensor_lite
)
cc_library
(
fc_compute_host SRCS fc_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
relu_compute_host SRCS relu_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
mul_compute_host SRCS mul_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
scale_compute_host SRCS scale_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
feed_compute_host SRCS feed_compute.cc DEPS
${
lite_kernel_deps
}
)
cc_library
(
host_kernels DEPS
fc_compute_host
...
...
@@ -10,7 +10,7 @@ cc_library(host_kernels DEPS
mul_compute_host
scale_compute_host
feed_compute_host
DEPS
kernel_lite
DEPS
${
lite_kernel_deps
}
)
cc_test
(
test_fc_compute SRCS fc_compute_test.cc DEPS fc_compute_host fc_op_lite
)
paddle/fluid/lite/kernels/host/mul_compute.cc
浏览文件 @
18959c09
...
...
@@ -68,4 +68,10 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL
(
mul
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
MulCompute
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
.
BindInput
(
"Y"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
.
Finalize
();
paddle/fluid/lite/utils/factory.h
浏览文件 @
18959c09
...
...
@@ -33,12 +33,12 @@ namespace lite {
* // Retrive a creator.
* auto some_type_instance = Factory<SomeType>::Global().Create("some_key");
*/
template
<
typename
ItemType
>
template
<
typename
ItemType
,
typename
ItemTypePtr
>
class
Factory
{
public:
using
item_t
=
ItemType
;
using
self_t
=
Factory
<
item_t
>
;
using
item_ptr_t
=
std
::
unique_ptr
<
item_t
>
;
using
self_t
=
Factory
<
item_t
,
ItemTypePtr
>
;
using
item_ptr_t
=
ItemTypePtr
;
using
creator_t
=
std
::
function
<
item_ptr_t
()
>
;
static
Factory
&
Global
()
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录