Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
97149f31
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
97149f31
编写于
4月 04, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update
上级
4eedd20f
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
118 addition
and
17 deletion
+118
-17
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+1
-0
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+7
-1
paddle/fluid/lite/core/kernel_test.cc
paddle/fluid/lite/core/kernel_test.cc
+4
-0
paddle/fluid/lite/core/memory.h
paddle/fluid/lite/core/memory.h
+0
-6
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+32
-0
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+22
-4
paddle/fluid/lite/core/op_lite_test.cc
paddle/fluid/lite/core/op_lite_test.cc
+13
-0
paddle/fluid/lite/core/op_registry.h
paddle/fluid/lite/core/op_registry.h
+36
-2
paddle/fluid/lite/core/target_wrapper.h
paddle/fluid/lite/core/target_wrapper.h
+1
-1
paddle/fluid/lite/core/types.h
paddle/fluid/lite/core/types.h
+2
-3
未找到文件。
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
97149f31
...
@@ -8,3 +8,4 @@ cc_library(scope_lite SRCS scope.cc)
...
@@ -8,3 +8,4 @@ cc_library(scope_lite SRCS scope.cc)
cc_test
(
test_scope_lite SRCS scope_test.cc DEPS scope_lite
)
cc_test
(
test_scope_lite SRCS scope_test.cc DEPS scope_lite
)
cc_test
(
test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86
)
cc_test
(
test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86
)
cc_test
(
test_op_lite SRCS op_lite_test.cc DEPS op_lite
)
paddle/fluid/lite/core/kernel.h
浏览文件 @
97149f31
...
@@ -26,6 +26,8 @@
...
@@ -26,6 +26,8 @@
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
// An base with virtual functions to unify all the kernel implementation on
// different targets.
class
KernelBase
{
class
KernelBase
{
public:
public:
virtual
void
Run
()
=
0
;
virtual
void
Run
()
=
0
;
...
@@ -45,8 +47,12 @@ class KernelBase {
...
@@ -45,8 +47,12 @@ class KernelBase {
return
param_
.
get
<
Param
>
();
return
param_
.
get
<
Param
>
();
}
}
protected:
virtual
TargetType
target
()
const
=
0
;
virtual
PrecisionType
precision
()
const
=
0
;
virtual
~
KernelBase
()
=
default
;
virtual
~
KernelBase
()
=
default
;
protected:
core
::
any_context_t
context_
;
core
::
any_context_t
context_
;
mutable
operators
::
param_t
param_
;
mutable
operators
::
param_t
param_
;
};
};
...
...
paddle/fluid/lite/core/kernel_test.cc
浏览文件 @
97149f31
...
@@ -28,6 +28,10 @@ class SomeKernel : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
...
@@ -28,6 +28,10 @@ class SomeKernel : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
LOG
(
INFO
)
<<
param
<
operators
::
FcParam
>
().
in_num_col_dims
;
LOG
(
INFO
)
<<
param
<
operators
::
FcParam
>
().
in_num_col_dims
;
test_code
=
param
<
operators
::
FcParam
>
().
in_num_col_dims
;
test_code
=
param
<
operators
::
FcParam
>
().
in_num_col_dims
;
}
}
TargetType
target
()
const
override
{
return
TARGET
(
kHost
);
}
PrecisionType
precision
()
const
override
{
return
PRECISION
(
kFloat
);
}
};
};
TEST
(
Kernel
,
test
)
{
TEST
(
Kernel
,
test
)
{
...
...
paddle/fluid/lite/core/memory.h
浏览文件 @
97149f31
...
@@ -28,9 +28,6 @@ static void* TargetMalloc(TargetType target, size_t size) {
...
@@ -28,9 +28,6 @@ static void* TargetMalloc(TargetType target, size_t size) {
case
static_cast
<
int
>
(
TargetType
::
kCUDA
):
case
static_cast
<
int
>
(
TargetType
::
kCUDA
):
data
=
TargetWrapper
<
TARGET
(
kCUDA
)
>::
Malloc
(
size
);
data
=
TargetWrapper
<
TARGET
(
kCUDA
)
>::
Malloc
(
size
);
break
;
break
;
case
static_cast
<
int
>
(
TargetType
::
kARM
):
data
=
TargetWrapper
<
TARGET
(
kARM
)
>::
Malloc
(
size
);
break
;
case
static_cast
<
int
>
(
TargetType
::
kHost
):
case
static_cast
<
int
>
(
TargetType
::
kHost
):
data
=
TargetWrapper
<
TARGET
(
kHost
)
>::
Malloc
(
size
);
data
=
TargetWrapper
<
TARGET
(
kHost
)
>::
Malloc
(
size
);
break
;
break
;
...
@@ -48,9 +45,6 @@ static void TargetFree(TargetType target, void* data) {
...
@@ -48,9 +45,6 @@ static void TargetFree(TargetType target, void* data) {
case
static_cast
<
int
>
(
TargetType
::
kCUDA
):
case
static_cast
<
int
>
(
TargetType
::
kCUDA
):
TargetWrapper
<
TARGET
(
kX86
)
>::
Free
(
data
);
TargetWrapper
<
TARGET
(
kX86
)
>::
Free
(
data
);
break
;
break
;
case
static_cast
<
int
>
(
TargetType
::
kARM
):
TargetWrapper
<
TARGET
(
kX86
)
>::
Free
(
data
);
break
;
default:
default:
LOG
(
FATAL
)
<<
"Unknown type"
;
LOG
(
FATAL
)
<<
"Unknown type"
;
}
}
...
...
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
97149f31
...
@@ -12,4 +12,36 @@
...
@@ -12,4 +12,36 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/lite/core/op_lite.h"
#include "op_lite.h"
#include "op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
OpLite
::
CreateKernels
(
const
std
::
vector
<
OpLite
::
Place
>
&
places
)
{
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
kernels
;
CHECK
(
!
op_type_
.
empty
())
<<
"op_type_ should be set first"
;
for
(
auto
place
:
places
)
{
kernels
.
emplace_back
(
KernelRegistry
::
Global
().
Create
(
op_type_
,
place
.
target
,
place
.
precision
));
}
return
kernels
;
}
void
OpLite
::
PickKernel
(
const
std
::
vector
<
OpLite
::
Place
>
&
valid_places
,
OpLite
::
KernelStrategy
kernel_strategy
)
{
switch
(
kernel_strategy
)
{
case
KernelStrategy
::
kStatic
:
StaticPickKernel
(
valid_places
);
break
;
default:
LOG
(
FATAL
)
<<
"unsupported kernel strategy"
;
}
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/op_lite.h
浏览文件 @
97149f31
...
@@ -70,7 +70,15 @@ class OpLite : public Registry {
...
@@ -70,7 +70,15 @@ class OpLite : public Registry {
// Inference the outputs' shape.
// Inference the outputs' shape.
virtual
bool
InferShape
()
const
{
return
true
;
}
virtual
bool
InferShape
()
const
{
return
true
;
}
// Run this operator.
// Run this operator.
virtual
bool
Run
()
=
0
;
virtual
bool
Run
()
{
CHECK
(
kernel_
);
SyncInputEvents
();
kernel_
->
Run
();
RecordOutputEvents
();
return
true
;
}
// Build the operator, attach it with the runtime environment.
// Build the operator, attach it with the runtime environment.
virtual
bool
Build
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
virtual
bool
Build
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
// Human-readable information.
// Human-readable information.
...
@@ -79,21 +87,31 @@ class OpLite : public Registry {
...
@@ -79,21 +87,31 @@ class OpLite : public Registry {
const
Place
&
kernel_place
()
const
{
return
kernel_place_
;
}
const
Place
&
kernel_place
()
const
{
return
kernel_place_
;
}
protected:
protected:
void
PickKernel
(
const
std
::
vector
<
Place
>
&
valid_places
,
KernelStrategy
kernel_strategy
=
KernelStrategy
::
kStatic
);
// Specify the kernel to run by default. This will specify the value of
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
// `kernel_place_`.
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
=
0
;
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
=
0
;
void
PickKernel
(
const
std
::
vector
<
Place
>
&
valid_places
,
// Wait until all the inputs' events are ready.
KernelStrategy
kernel_strategy
=
KernelStrategy
::
kStatic
);
void
SyncInputEvents
()
{}
// Record the output events, and that will tell all the dependent operators
// some inputs are ready.
void
RecordOutputEvents
()
{}
// Create all the kernels for the valid targets.
// Create all the kernels for the valid targets.
void
CreateKernels
();
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
CreateKernels
(
const
std
::
vector
<
Place
>
&
places
);
virtual
~
OpLite
()
=
default
;
virtual
~
OpLite
()
=
default
;
protected:
protected:
std
::
unique_ptr
<
OpContext
>
op_context_
;
std
::
unique_ptr
<
OpContext
>
op_context_
;
Place
kernel_place_
;
Place
kernel_place_
;
std
::
unique_ptr
<
KernelBase
>
kernel_
;
std
::
string
op_type_
;
};
};
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/op_lite_test.cc
0 → 100644
浏览文件 @
97149f31
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_lite.h"
namespace
paddle
{
namespace
lite
{
TEST
(
OpLite
,
test
)
{
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/op_registry.h
浏览文件 @
97149f31
...
@@ -59,7 +59,6 @@ class KernelRegistry final {
...
@@ -59,7 +59,6 @@ class KernelRegistry final {
KernelRegistryForTarget
<
TARGET
(
kCUDA
),
PRECISION
(
kInt8
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kCUDA
),
PRECISION
(
kInt8
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kFloat
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kFloat
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kInt8
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kInt8
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kARM
),
PRECISION
(
kFloat
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
*
//
KernelRegistryForTarget
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
*
//
>
;
>
;
...
@@ -77,7 +76,6 @@ registries_[0].set<kernel_target_t *>(
...
@@ -77,7 +76,6 @@ registries_[0].set<kernel_target_t *>(
*>(&KernelRegistryForTarget<TARGET(target__), \
*>(&KernelRegistryForTarget<TARGET(target__), \
PRECISION(precision__)>::Global());
PRECISION(precision__)>::Global());
// Currently, just register 2 kernel targets.
// Currently, just register 2 kernel targets.
INIT_FOR
(
kARM
,
kFloat
);
INIT_FOR
(
kHost
,
kFloat
);
INIT_FOR
(
kHost
,
kFloat
);
#undef INIT_FOR
#undef INIT_FOR
}
}
...
@@ -97,6 +95,42 @@ registries_[0].set<kernel_target_t *>(
...
@@ -97,6 +95,42 @@ registries_[0].set<kernel_target_t *>(
->
Register
(
name
,
std
::
move
(
creator
));
->
Register
(
name
,
std
::
move
(
creator
));
}
}
template
<
TargetType
Target
,
PrecisionType
Precision
>
std
::
unique_ptr
<
KernelBase
>
Create
(
const
std
::
string
&
op_type
)
{
using
kernel_registor_t
=
KernelRegistryForTarget
<
Target
,
Precision
>
;
return
registries_
[
GetKernelOffset
<
Target
,
Precision
>
()]
.
template
get
<
kernel_registor_t
*
>()
->
Create
(
op_type
);
}
std
::
unique_ptr
<
KernelBase
>
Create
(
const
std
::
string
&
op_type
,
TargetType
target
,
PrecisionType
precision
)
{
#define CREATE_KERNEL(target__) \
switch (precision) { \
case PRECISION(kFloat): \
return Create<TARGET(target__), PRECISION(kFloat)>(op_type); \
default: \
CHECK(false) << "not supported kernel place yet"; \
}
switch
(
target
)
{
case
TARGET
(
kHost
):
{
CREATE_KERNEL
(
kHost
);
}
break
;
case
TARGET
(
kX86
):
{
CREATE_KERNEL
(
kX86
);
}
break
;
case
TARGET
(
kCUDA
):
{
CREATE_KERNEL
(
kCUDA
);
}
break
;
default:
CHECK
(
false
)
<<
"not supported kernel place"
;
}
#undef CREATE_KERNEL
}
// Get a kernel registry offset in all the registries.
// Get a kernel registry offset in all the registries.
template
<
TargetType
Target
,
PrecisionType
Precision
>
template
<
TargetType
Target
,
PrecisionType
Precision
>
static
constexpr
int
GetKernelOffset
()
{
static
constexpr
int
GetKernelOffset
()
{
...
...
paddle/fluid/lite/core/target_wrapper.h
浏览文件 @
97149f31
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
enum
class
TargetType
{
kHost
=
0
,
kX86
,
kCUDA
,
k
ARM
,
k
LastAsPlaceHolder
};
enum
class
TargetType
{
kHost
=
0
,
kX86
,
kCUDA
,
kLastAsPlaceHolder
};
// Some helper macro to get a specific TargetType.
// Some helper macro to get a specific TargetType.
#define TARGET(item__) paddle::lite::TargetType::item__
#define TARGET(item__) paddle::lite::TargetType::item__
#define TARGET_VAL(item__) static_cast<int>(TARGET(item__))
#define TARGET_VAL(item__) static_cast<int>(TARGET(item__))
...
...
paddle/fluid/lite/core/types.h
浏览文件 @
97149f31
...
@@ -21,9 +21,8 @@ namespace paddle {
...
@@ -21,9 +21,8 @@ namespace paddle {
namespace
lite
{
namespace
lite
{
namespace
core
{
namespace
core
{
using
any_context_t
=
variant
<
Context
<
TARGET
(
kX86
)
>
,
//
using
any_context_t
=
variant
<
Context
<
TARGET
(
kX86
)
>
,
//
Context
<
TARGET
(
kCUDA
)
>
,
//
Context
<
TARGET
(
kCUDA
)
>
//
Context
<
TARGET
(
kARM
)
>
//
>
;
>
;
}
// namespace core
}
// namespace core
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录