Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
bac1426d
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
bac1426d
编写于
7月 14, 2017
作者:
Q
qijun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add_op kernel implementation
上级
6f2eba3e
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
97 addition
and
38 deletion
+97
-38
paddle/framework/operator.cc
paddle/framework/operator.cc
+12
-0
paddle/framework/operator.h
paddle/framework/operator.h
+44
-23
paddle/framework/tensor.h
paddle/framework/tensor.h
+15
-1
paddle/operators/add_op.cc
paddle/operators/add_op.cc
+6
-5
paddle/operators/add_op.cu
paddle/operators/add_op.cu
+5
-3
paddle/operators/add_op.h
paddle/operators/add_op.h
+15
-6
未找到文件。
paddle/framework/operator.cc
浏览文件 @
bac1426d
...
...
@@ -17,6 +17,18 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
template
<
>
DeviceType
*
KernelContext
::
get_eigen_device
<
CPUPlace
>
()
{
return
device_context_
.
get_eigen_device
<
DeviceType
>
();
}
#ifndef PADDLE_ONLY_CPU
template
<
>
DeviceType
*
KernelContext
::
get_eigen_device
<
GPUPlace
>
()
{
return
device_context_
.
get_eigen_device
<
DeviceType
>
();
}
#endif
std
::
string
OperatorBase
::
DebugString
()
const
{
std
::
stringstream
ss
;
ss
<<
"=================
\n
"
;
...
...
paddle/framework/operator.h
浏览文件 @
bac1426d
...
...
@@ -29,6 +29,21 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
template
<
typename
T
>
struct
EigenDeviceConverter
;
template
<
>
struct
EigenDeviceConverter
<
CPUPlace
>
{
using
EigenDeviceType
=
Eigen
::
DefaultDevice
;
};
#ifndef PADDLE_ONLY_CPU
template
<
>
struct
EigenDeviceConverter
<
GPUPlace
>
{
using
EigenDeviceType
=
Eigen
::
GpuDevice
;
};
#endif
class
OperatorBase
;
/**
...
...
@@ -72,15 +87,13 @@ class OperatorBase {
AttributeMap
attrs_
;
};
class
OpKernel
{
public:
/**
/**
* KernelContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class
KernelContext
{
class
KernelContext
{
public:
KernelContext
(
const
OperatorBase
*
op
,
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
platform
::
DeviceContext
&
device_context
)
...
...
@@ -94,11 +107,19 @@ class OpKernel {
return
scope_
->
GetVariable
(
op_
.
outputs_
[
index
]);
}
platform
::
DeviceContext
&
device_context
()
const
{
return
device_context_
;
}
template
<
typename
PlaceType
,
typename
DeviceType
=
EigenDeviceConverter
<
PlaceType
>
::
EigenDeviceType
>
DeviceType
*
get_eigen_device
();
const
OperatorBase
&
op_
;
const
std
::
shared_ptr
<
Scope
>&
scope_
;
const
platform
::
DeviceContext
&
device_context_
;
};
};
class
OpKernel
{
public:
virtual
void
Compute
(
const
KernelContext
&
context
)
const
=
0
;
virtual
~
OpKernel
()
{}
...
...
paddle/framework/tensor.h
浏览文件 @
bac1426d
...
...
@@ -35,7 +35,7 @@ class Tensor {
template
<
typename
T
>
T
*
data
()
const
{
const
T
*
data
()
const
{
PADDLE_ENFORCE
(
holder_
!=
nullptr
,
"Tenosr has not been initialized. Call Tensor::mutable_data first."
);
...
...
@@ -58,6 +58,20 @@ class Tensor {
offset_
);
}
template
<
typename
T
,
// must be POD types
typename
std
::
enable_if
<
std
::
is_pod
<
T
>
::
value
>::
type
*
=
nullptr
>
T
*
mutable_data
(
paddle
::
platform
::
Place
place
)
{
if
(
holder_
==
nullptr
||
!
(
holder_
->
Place
()
==
place
)
/* some versions of boost::variant don't have operator!= */
||
holder_
->
Size
()
<
product
(
dims_
)
*
sizeof
(
T
)
+
offset_
)
{
holder_
.
reset
(
new
PlaceholderImpl
<
T
>
(
place
,
product
(
dims_
)
*
sizeof
(
T
)));
offset_
=
0
;
}
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
Ptr
())
+
offset_
);
}
size_t
NumElements
()
const
{
return
product
(
dims_
);
}
template
<
typename
T
,
size_t
NDIMS
>
...
...
paddle/operators/add_op.cc
浏览文件 @
bac1426d
#include
<paddle/framework/op_registry.h>
#include
<paddle/framework/tensor.h>
#include
<paddle/operators/add_op.h>
#include
"paddle/operators/add_op.h"
#include
"paddle/framework/op_registry.h"
#include
"paddle/framework/tensor.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -36,9 +36,10 @@ The equation is: Out = X + Y
)DOC"
);
}
};
}
// namespace op
}
// namespace op
erators
}
// namespace paddle
REGISTER_OP
(
add_two
,
paddle
::
operators
::
AddOp
,
paddle
::
operators
::
AddOpMaker
);
REGISTER_OP_CPU_KERNEL
(
add_two
,
::
paddle
::
operators
::
AddKernel
<::
paddle
::
platform
::
CPUPlace
>
);
\ No newline at end of file
add_two
,
::
paddle
::
operators
::
AddKernel
<::
paddle
::
platform
::
CPUPlace
,
float
>
);
\ No newline at end of file
paddle/operators/add_op.cu
浏览文件 @
bac1426d
#include <paddle/operators/add_op.h>
#include <paddle/framework/op_registry.h>
#define EIGEN_USE_GPU
#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"
REGISTER_OP_GPU_KERNEL
(
add_two
,
paddle
::
operators
::
AddKernel
<
paddle
::
platform
::
GPUPlace
>
);
\ No newline at end of file
paddle
::
operators
::
AddKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
\ No newline at end of file
paddle/operators/add_op.h
浏览文件 @
bac1426d
#pragma once
#include <glog/logging.h>
#include <paddle/framework/operator.h>
#include "glog/logging.h"
#include "paddle/framework/operator.h"
//#include "paddle/operators/add_op_functor.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
Place
>
// Place can be CPUPlace or GPUPlace
template
<
typename
Place
,
typename
DataType
>
class
AddKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
KernelContext
&
context
)
const
override
{
LOG
(
INFO
)
<<
"Add kernel in "
<<
typeid
(
Place
).
name
();
void
Compute
(
const
KernelContext
&
context
)
const
override
{
auto
*
input0
=
context
.
Input
(
0
);
auto
*
input1
=
context
.
Input
(
1
);
auto
*
output
=
context
.
Output
(
0
);
output
->
mutable_data
<
DataType
>
(
Place
());
output
->
flat
<
T
>
().
device
(
*
(
context
.
get_eigen_device
<
Place
>
()))
=
input0
->
flat
<
T
>
()
+
input1
->
flat
<
T
>
();
}
};
}
// namespace op
}
// namespace op
erators
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录