Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4eedd20f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4eedd20f
编写于
4月 03, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
make kernel implementation works
上级
f3d1fac2
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
145 addition
and
44 deletion
+145
-44
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+1
-0
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+27
-19
paddle/fluid/lite/core/kernel_test.cc
paddle/fluid/lite/core/kernel_test.cc
+44
-0
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+13
-5
paddle/fluid/lite/core/op_registry.h
paddle/fluid/lite/core/op_registry.h
+8
-8
paddle/fluid/lite/core/types.h
paddle/fluid/lite/core/types.h
+31
-0
paddle/fluid/lite/operators/fc_op.h
paddle/fluid/lite/operators/fc_op.h
+2
-2
paddle/fluid/lite/utils/factory.h
paddle/fluid/lite/utils/factory.h
+13
-0
paddle/fluid/lite/x86/target_wrapper.cc
paddle/fluid/lite/x86/target_wrapper.cc
+6
-8
paddle/fluid/lite/x86/target_wrapper.h
paddle/fluid/lite/x86/target_wrapper.h
+0
-2
未找到文件。
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
4eedd20f
...
...
@@ -7,3 +7,4 @@ cc_library(op_registry_lite SRCS op_registry.cc)
cc_library
(
scope_lite SRCS scope.cc
)
cc_test
(
test_scope_lite SRCS scope_test.cc DEPS scope_lite
)
cc_test
(
test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86
)
paddle/fluid/lite/core/kernel.h
浏览文件 @
4eedd20f
...
...
@@ -14,46 +14,54 @@
#pragma once
#include <glog/logging.h>
#include <boost/variant.hpp>
#include <map>
#include <string>
#include "context.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
#include "target_wrapper.h"
namespace
paddle
{
namespace
lite
{
// Light-weight kernel implementation.
// The OpKernel is designed to implement the specific algorithm on a target
// device.
template
<
TargetType
Target
,
PrecisionType
Precision
>
class
OpKernel
{
class
KernelBase
{
public:
using
context_t
=
Context
<
Target
>
;
using
context_ptr_t
=
std
::
unique_ptr
<
context_t
>
;
OpKernel
()
=
default
;
virtual
void
Run
()
=
0
;
void
SetContext
(
context_ptr_t
&&
ctx
)
{
context_
=
std
::
move
(
ctx
);
}
template
<
TargetType
Target
>
void
SetContext
(
std
::
unique_ptr
<
Context
<
Target
>>&&
ctx
)
{
context_
.
set
<
std
::
unique_ptr
<
Context
<
Target
>>>
(
std
::
move
(
ctx
));
}
void
SetParam
(
operators
::
param_t
param
)
{
param_
=
param
;
}
template
<
typename
T
>
void
SetParam
(
T
param
)
{
param_
.
set
<
T
>
(
param
);
}
template
<
typename
Param
>
Param
&
param
()
const
{
return
param_
.
get
<
Param
>
();
}
protected:
virtual
~
KernelBase
()
=
default
;
core
::
any_context_t
context_
;
mutable
operators
::
param_t
param_
;
};
// Light-weight kernel implementation.
// The OpKernel is designed to implement the specific algorithm on a target
// device.
template
<
TargetType
Target
,
PrecisionType
Precision
>
class
OpKernel
:
public
KernelBase
{
public:
virtual
void
Run
()
{
CHECK
(
false
)
<<
"Not Implemented"
;
}
virtual
~
OpKernel
()
=
default
;
OpKernel
()
=
default
;
protected:
context_ptr_t
context_
;
mutable
operators
::
param_t
param_
;
virtual
~
OpKernel
()
=
default
;
};
}
// namespace lite
...
...
paddle/fluid/lite/core/kernel_test.cc
0 → 100644
浏览文件 @
4eedd20f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/kernel.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_lite.h"
namespace
paddle
{
namespace
lite
{
namespace
core
{
int
test_code
{
-
1
};
class
SomeKernel
:
public
OpKernel
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
{
public:
void
Run
()
override
{
LOG
(
INFO
)
<<
"SomeKernel executed"
;
LOG
(
INFO
)
<<
param
<
operators
::
FcParam
>
().
in_num_col_dims
;
test_code
=
param
<
operators
::
FcParam
>
().
in_num_col_dims
;
}
};
TEST
(
Kernel
,
test
)
{
SomeKernel
kernel
;
operators
::
FcParam
param
;
param
.
in_num_col_dims
=
100
;
kernel
.
SetParam
<
operators
::
FcParam
>
(
param
);
kernel
.
Run
();
ASSERT_EQ
(
test_code
,
100
);
}
}
// namespace core
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/op_lite.h
浏览文件 @
4eedd20f
...
...
@@ -57,7 +57,12 @@ class OpLite : public Registry {
kRuntime
,
};
OpLite
()
{}
struct
Place
{
TargetType
target
{
TARGET
(
kHost
)};
PrecisionType
precision
{
PRECISION
(
kFloat
)};
};
OpLite
()
=
default
;
OpLite
(
std
::
unique_ptr
<
OpContext
>
&&
x
)
:
op_context_
(
std
::
move
(
x
))
{}
// Check the shape.
...
...
@@ -71,12 +76,14 @@ class OpLite : public Registry {
// Human-readable information.
virtual
std
::
string
DebugString
()
const
=
0
;
const
Place
&
kernel_place
()
const
{
return
kernel_place_
;
}
protected:
// Specify the kernel to run by default.
virtual
void
StaticPickKernel
(
const
std
::
vector
<
TargetTyp
e
>
&
valid_targets
)
=
0
;
// Specify the kernel to run by default.
This will specify the value of
// `kernel_place_`.
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Plac
e
>
&
valid_targets
)
=
0
;
void
PickKernel
(
const
std
::
vector
<
TargetTyp
e
>
&
valid_places
,
void
PickKernel
(
const
std
::
vector
<
Plac
e
>
&
valid_places
,
KernelStrategy
kernel_strategy
=
KernelStrategy
::
kStatic
);
// Create all the kernels for the valid targets.
...
...
@@ -86,6 +93,7 @@ class OpLite : public Registry {
protected:
std
::
unique_ptr
<
OpContext
>
op_context_
;
Place
kernel_place_
;
};
}
// namespace lite
...
...
paddle/fluid/lite/core/op_registry.h
浏览文件 @
4eedd20f
...
...
@@ -54,14 +54,14 @@ using KernelRegistryForTarget = Factory<OpKernel<Target, Precision>>;
class
KernelRegistry
final
{
public:
using
any_kernel_registor_t
=
variant
<
KernelRegistryForTarget
<
TargetType
::
kCUDA
,
PrecisionType
::
kFloat
>
*
,
//
KernelRegistryForTarget
<
TargetType
::
kCUDA
,
PrecisionType
::
kInt8
>
*
,
//
KernelRegistryForTarget
<
TargetType
::
kX86
,
PrecisionType
::
kFloat
>
*
,
//
KernelRegistryForTarget
<
TargetType
::
kX86
,
PrecisionType
::
kInt8
>
*
,
//
KernelRegistryForTarget
<
TargetType
::
kARM
,
PrecisionType
::
kFloat
>
*
,
//
KernelRegistryForTarget
<
TargetType
::
kHost
,
PrecisionType
::
kFloat
>
*
//
>
;
using
any_kernel_registor_t
=
variant
<
KernelRegistryForTarget
<
TARGET
(
kCUDA
),
PRECISION
(
kFloat
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kCUDA
),
PRECISION
(
kInt8
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kFloat
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kX86
),
PRECISION
(
kInt8
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kARM
),
PRECISION
(
kFloat
)
>
*
,
//
KernelRegistryForTarget
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
*
//
>
;
KernelRegistry
()
{
/*
...
...
paddle/fluid/lite/core/types.h
0 → 100644
浏览文件 @
4eedd20f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/utils/all.h"
namespace
paddle
{
namespace
lite
{
namespace
core
{
using
any_context_t
=
variant
<
Context
<
TARGET
(
kX86
)
>
,
//
Context
<
TARGET
(
kCUDA
)
>
,
//
Context
<
TARGET
(
kARM
)
>
//
>
;
}
// namespace core
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/operators/fc_op.h
浏览文件 @
4eedd20f
...
...
@@ -14,6 +14,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/core/tensor.h"
...
...
@@ -53,8 +54,7 @@ class FcOpLite : public OpLite {
std
::
string
DebugString
()
const
override
{
return
"fc"
;
}
void
StaticPickKernel
(
const
std
::
vector
<
TargetType
>&
valid_targets
)
override
{
}
void
StaticPickKernel
(
const
std
::
vector
<
Place
>&
valid_targets
)
override
{}
private:
mutable
FcParam
param_
;
...
...
paddle/fluid/lite/utils/factory.h
浏览文件 @
4eedd20f
...
...
@@ -19,6 +19,18 @@
namespace
paddle
{
namespace
lite
{
/*
* Factor for any Type creator.
*
* Usage:
*
* struct SomeType;
* // Register a creator.
* Factory<SomeType>::Global().Register("some_key", [] ->
* std::unique_ptr<SomeType> { ... });
* // Retrive a creator.
* auto some_type_instance = Factory<SomeType>::Global().Create("some_key");
*/
template
<
typename
ItemType
>
class
Factory
{
public:
...
...
@@ -55,6 +67,7 @@ class Registor {
public:
Registor
(
std
::
function
<
void
()
>&&
functor
)
{
functor
();
}
// Touch will do nothing.
int
Touch
()
{
return
0
;
}
};
...
...
paddle/fluid/lite/x86/target_wrapper.cc
浏览文件 @
4eedd20f
...
...
@@ -12,22 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "target_wrapper.h"
#include "
paddle/fluid/lite/core/
target_wrapper.h"
#include <algorithm>
namespace
paddle
{
namespace
framework
{
namespace
lite
{
template
<
>
void
TargetWrapper
<
X86
>::
MemcpySync
(
void
*
dst
,
void
*
src
,
size_t
size
,
IoDirection
dir
)
{
std
::
copy_n
(
reinterpret_cast
<
uint8_t
*>
(
src
),
size
,
reinterpret_cast
<
uint8_t
*>
(
dst
));
void
TargetWrapper
<
TARGET
(
kX86
)
>::
MemcpySync
(
void
*
dst
,
void
*
src
,
size_t
size
,
IoDirection
dir
)
{
std
::
copy_n
(
reinterpret_cast
<
uint8_t
*>
(
src
),
size
,
reinterpret_cast
<
uint8_t
*>
(
dst
));
}
template
class
TargetWrapper
<
X86
>;
template
class
TargetWrapper
<
TARGET
(
kX86
)
>;
}
// namespace lite
}
// namespace framework
}
// namespace paddle
paddle/fluid/lite/x86/target_wrapper.h
浏览文件 @
4eedd20f
...
...
@@ -16,9 +16,7 @@
#include "paddle/fluid/lite/core/target_wrapper.h"
namespace
paddle
{
namespace
framework
{
namespace
lite
{
namespace
x86
{}
// namespace x86
}
// namespace lite
}
// namespace framework
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录