Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
97149f31
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
97149f31
编写于
4月 04, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update
上级
4eedd20f
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
118 addition
and
17 deletion
+118
-17
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+1
-0
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+7
-1
paddle/fluid/lite/core/kernel_test.cc
paddle/fluid/lite/core/kernel_test.cc
+4
-0
paddle/fluid/lite/core/memory.h
paddle/fluid/lite/core/memory.h
+0
-6
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+32
-0
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+22
-4
paddle/fluid/lite/core/op_lite_test.cc
paddle/fluid/lite/core/op_lite_test.cc
+13
-0
paddle/fluid/lite/core/op_registry.h
paddle/fluid/lite/core/op_registry.h
+36
-2
paddle/fluid/lite/core/target_wrapper.h
paddle/fluid/lite/core/target_wrapper.h
+1
-1
paddle/fluid/lite/core/types.h
paddle/fluid/lite/core/types.h
+2
-3
未找到文件。
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
97149f31
...
@@ -8,3 +8,4 @@ cc_library(scope_lite SRCS scope.cc)
...
@@ -8,3 +8,4 @@ cc_library(scope_lite SRCS scope.cc)
cc_test
(
test_scope_lite SRCS scope_test.cc DEPS scope_lite
)
cc_test
(
test_scope_lite SRCS scope_test.cc DEPS scope_lite
)
cc_test
(
test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86
)
cc_test
(
test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86
)
cc_test
(
test_op_lite SRCS op_lite_test.cc DEPS op_lite
)
paddle/fluid/lite/core/kernel.h
浏览文件 @
97149f31
...
@@ -26,6 +26,8 @@
...
@@ -26,6 +26,8 @@
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
// An base with virtual functions to unify all the kernel implementation on
// different targets.
class
KernelBase
{
class
KernelBase
{
public:
public:
virtual
void
Run
()
=
0
;
virtual
void
Run
()
=
0
;
...
@@ -45,8 +47,12 @@ class KernelBase {
...
@@ -45,8 +47,12 @@ class KernelBase {
return
param_
.
get
<
Param
>
();
return
param_
.
get
<
Param
>
();
}
}
protected:
virtual
TargetType
target
()
const
=
0
;
virtual
PrecisionType
precision
()
const
=
0
;
virtual
~
KernelBase
()
=
default
;
virtual
~
KernelBase
()
=
default
;
protected:
core
::
any_context_t
context_
;
core
::
any_context_t
context_
;
mutable
operators
::
param_t
param_
;
mutable
operators
::
param_t
param_
;
};
};
...
...
paddle/fluid/lite/core/kernel_test.cc
浏览文件 @
97149f31
...
@@ -28,6 +28,10 @@ class SomeKernel : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
...
@@ -28,6 +28,10 @@ class SomeKernel : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
LOG
(
INFO
)
<<
param
<
operators
::
FcParam
>
().
in_num_col_dims
;
LOG
(
INFO
)
<<
param
<
operators
::
FcParam
>
().
in_num_col_dims
;
test_code
=
param
<
operators
::
FcParam
>
().
in_num_col_dims
;
test_code
=
param
<
operators
::
FcParam
>
().
in_num_col_dims
;
}
}
TargetType
target
()
const
override
{
return
TARGET
(
kHost
);
}
PrecisionType
precision
()
const
override
{
return
PRECISION
(
kFloat
);
}
};
};
TEST
(
Kernel
,
test
)
{
TEST
(
Kernel
,
test
)
{
...
...
paddle/fluid/lite/core/memory.h
浏览文件 @
97149f31
...
@@ -28,9 +28,6 @@ static void* TargetMalloc(TargetType target, size_t size) {
...
@@ -28,9 +28,6 @@ static void* TargetMalloc(TargetType target, size_t size) {
case
static_cast
<
int
>
(
TargetType
::
kCUDA
):
case
static_cast
<
int
>
(
TargetType
::
kCUDA
):
data
=
TargetWrapper
<
TARGET
(
kCUDA
)
>::
Malloc
(
size
);
data
=
TargetWrapper
<
TARGET
(
kCUDA
)
>::
Malloc
(
size
);
break
;
break
;
case
static_cast
<
int
>
(
TargetType
::
kARM
):
data
=
TargetWrapper
<
TARGET
(
kARM
)
>::
Malloc
(
size
);
break
;
case
static_cast
<
int
>
(
TargetType
::
kHost
):
case
static_cast
<
int
>
(
TargetType
::
kHost
):
data
=
TargetWrapper
<
TARGET
(
kHost
)
>::
Malloc
(
size
);
data
=
TargetWrapper
<
TARGET
(
kHost
)
>::
Malloc
(
size
);
break
;
break
;
...
@@ -48,9 +45,6 @@ static void TargetFree(TargetType target, void* data) {
...
@@ -48,9 +45,6 @@ static void TargetFree(TargetType target, void* data) {
case
static_cast
<
int
>
(
TargetType
::
kCUDA
):
case
static_cast
<
int
>
(
TargetType
::
kCUDA
):
TargetWrapper
<
TARGET
(
kX86
)
>::
Free
(
data
);
TargetWrapper
<
TARGET
(
kX86
)
>::
Free
(
data
);
break
;
break
;
case
static_cast
<
int
>
(
TargetType
::
kARM
):
TargetWrapper
<
TARGET
(
kX86
)
>::
Free
(
data
);
break
;
default:
default:
LOG
(
FATAL
)
<<
"Unknown type"
;
LOG
(
FATAL
)
<<
"Unknown type"
;
}
}
...
...
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
97149f31
...
@@ -12,4 +12,36 @@
...
@@ -12,4 +12,36 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/op_lite.h"
#include "op_lite.h"
#include "op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
OpLite
::
CreateKernels
(
const
std
::
vector
<
OpLite
::
Place
>
&
places
)
{
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
kernels
;
CHECK
(
!
op_type_
.
empty
())
<<
"op_type_ should be set first"
;
for
(
auto
place
:
places
)
{
kernels
.
emplace_back
(
KernelRegistry
::
Global
().
Create
(
op_type_
,
place
.
target
,
place
.
precision
));
}
return
kernels
;
}
void
OpLite
::
PickKernel
(
const
std
::
vector
<
OpLite
::
Place
>
&
valid_places
,
OpLite
::
KernelStrategy
kernel_strategy
)
{
switch
(
kernel_strategy
)
{
case
KernelStrategy
::
kStatic
:
StaticPickKernel
(
valid_places
);
break
;
default:
LOG
(
FATAL
)
<<
"unsupported kernel strategy"
;
}
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/op_lite.h
浏览文件 @
97149f31
...
@@ -70,7 +70,15 @@ class OpLite : public Registry {
...
@@ -70,7 +70,15 @@ class OpLite : public Registry {
// Inference the outputs' shape.
// Inference the outputs' shape.
virtual
bool
InferShape
()
const
{
return
true
;
}
virtual
bool
InferShape
()
const
{
return
true
;
}
// Run this operator.
// Run this operator.
virtual
bool
Run
()
=
0
;
virtual
bool
Run
()
{
CHECK
(
kernel_
);
SyncInputEvents
();
kernel_
->
Run
();
RecordOutputEvents
();
return
true
;
}
// Build the operator, attach it with the runtime environment.
// Build the operator, attach it with the runtime environment.
virtual
bool
Build
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
virtual
bool
Build
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
// Human-readable information.
// Human-readable information.
...
@@ -79,21 +87,31 @@ class OpLite : public Registry {
...
@@ -79,21 +87,31 @@ class OpLite : public Registry {
const
Place
&
kernel_place
()
const
{
return
kernel_place_
;
}
const
Place
&
kernel_place
()
const
{
return
kernel_place_
;
}
protected:
protected:
void
PickKernel
(
const
std
::
vector
<
Place
>
&
valid_places
,
KernelStrategy
kernel_strategy
=
KernelStrategy
::
kStatic
);
// Specify the kernel to run by default. This will specify the value of
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
// `kernel_place_`.
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
=
0
;
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
=
0
;
void
PickKernel
(
const
std
::
vector
<
Place
>
&
valid_places
,
// Wait until all the inputs' events are ready.
KernelStrategy
kernel_strategy
=
KernelStrategy
::
kStatic
);
void
SyncInputEvents
()
{}
// Record the output events, and that will tell all the dependent operators
// some inputs are ready.
void
RecordOutputEvents
()
{}
// Create all the kernels for the valid targets.
// Create all the kernels for the valid targets.
void
CreateKernels
();
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
CreateKernels
(
const
std
::
vector
<
Place
>
&
places
);
virtual
~
OpLite
()
=
default
;
virtual
~
OpLite
()
=
default
;
protected:
protected:
std
::
unique_ptr
<
OpContext
>
op_context_
;
std
::
unique_ptr
<
OpContext
>
op_context_
;
Place
kernel_place_
;
Place
kernel_place_
;
std
::
unique_ptr
<
KernelBase
>
kernel_
;
std
::
string
op_type_
;
};
};
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/op_lite_test.cc
0 → 100644
浏览文件 @
97149f31
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_lite.h"
namespace
paddle
{
namespace
lite
{
TEST
(
OpLite
,
test
)
{
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/op_registry.h
浏览文件 @
97149f31
...
@@ -59,7 +59,6 @@ class KernelRegistry final {
...
@@ -59,7 +59,6 @@ class KernelRegistry final {
KernelRegistryForTarget
<
TARGET
(
kCUDA
),
PRECISION
(
kInt8
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kCUDA
),
PRECISION
(
kInt8
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kFloat
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kFloat
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kInt8
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kInt8
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kARM
),
PRECISION
(
kFloat
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
*
//
KernelRegistryForTarget
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
*
//
>
;
>
;
...
@@ -77,7 +76,6 @@ registries_[0].set<kernel_target_t *>(
...
@@ -77,7 +76,6 @@ registries_[0].set<kernel_target_t *>(
*>(&KernelRegistryForTarget<TARGET(target__), \
*>(&KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__)>::Global());
PRECISION(precision__)>::Global());
// Currently, just register 2 kernel targets.
// Currently, just register 2 kernel targets.
INIT_FOR
(
kARM
,
kFloat
);
INIT_FOR
(
kHost
,
kFloat
);
INIT_FOR
(
kHost
,
kFloat
);
#undef INIT_FOR
#undef INIT_FOR
}
}
...
@@ -97,6 +95,42 @@ registries_[0].set<kernel_target_t *>(
...
@@ -97,6 +95,42 @@ registries_[0].set<kernel_target_t *>(
->
Register
(
name
,
std
::
move
(
creator
));
->
Register
(
name
,
std
::
move
(
creator
));
}
}
template
<
TargetType
Target
,
PrecisionType
Precision
>
std
::
unique_ptr
<
KernelBase
>
Create
(
const
std
::
string
&
op_type
)
{
using
kernel_registor_t
=
KernelRegistryForTarget
<
Target
,
Precision
>
;
return
registries_
[
GetKernelOffset
<
Target
,
Precision
>
()]
.
template
get
<
kernel_registor_t
*
>()
->
Create
(
op_type
);
}
std
::
unique_ptr
<
KernelBase
>
Create
(
const
std
::
string
&
op_type
,
TargetType
target
,
PrecisionType
precision
)
{
#define CREATE_KERNEL(target__) \
switch (precision) { \
case PRECISION(kFloat): \
return Create<TARGET(target__), PRECISION(kFloat)>(op_type); \
default: \
CHECK(false) << "not supported kernel place yet"; \
}
switch
(
target
)
{
case
TARGET
(
kHost
):
{
CREATE_KERNEL
(
kHost
);
}
break
;
case
TARGET
(
kX86
):
{
CREATE_KERNEL
(
kX86
);
}
break
;
case
TARGET
(
kCUDA
):
{
CREATE_KERNEL
(
kCUDA
);
}
break
;
default:
CHECK
(
false
)
<<
"not supported kernel place"
;
}
#undef CREATE_KERNEL
}
// Get a kernel registry offset in all the registries.
// Get a kernel registry offset in all the registries.
template
<
TargetType
Target
,
PrecisionType
Precision
>
template
<
TargetType
Target
,
PrecisionType
Precision
>
static
constexpr
int
GetKernelOffset
()
{
static
constexpr
int
GetKernelOffset
()
{
...
...
paddle/fluid/lite/core/target_wrapper.h
浏览文件 @
97149f31
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
enum
class
TargetType
{
kHost
=
0
,
kX86
,
kCUDA
,
k
ARM
,
k
LastAsPlaceHolder
};
enum
class
TargetType
{
kHost
=
0
,
kX86
,
kCUDA
,
kLastAsPlaceHolder
};
// Some helper macro to get a specific TargetType.
// Some helper macro to get a specific TargetType.
#define TARGET(item__) paddle::lite::TargetType::item__
#define TARGET(item__) paddle::lite::TargetType::item__
#define TARGET_VAL(item__) static_cast<int>(TARGET(item__))
#define TARGET_VAL(item__) static_cast<int>(TARGET(item__))
...
...
paddle/fluid/lite/core/types.h
浏览文件 @
97149f31
...
@@ -21,9 +21,8 @@ namespace paddle {
...
@@ -21,9 +21,8 @@ namespace paddle {
namespace
lite
{
namespace
lite
{
namespace
core
{
namespace
core
{
using
any_context_t
=
variant
<
Context
<
TARGET
(
kX86
)
>
,
//
using
any_context_t
=
variant
<
Context
<
TARGET
(
kX86
)
>
,
//
Context
<
TARGET
(
kCUDA
)
>
,
//
Context
<
TARGET
(
kCUDA
)
>
//
Context
<
TARGET
(
kARM
)
>
//
>
;
>
;
}
// namespace core
}
// namespace core
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录