Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d649dbf4
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d649dbf4
编写于
7月 17, 2017
作者:
Q
qijun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
implement add_op kernel
上级
bac1426d
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
58 addition
and
61 deletion
+58
-61
paddle/framework/operator.cc
paddle/framework/operator.cc
+5
-3
paddle/framework/operator.h
paddle/framework/operator.h
+29
-30
paddle/framework/tensor.h
paddle/framework/tensor.h
+3
-3
paddle/operators/add_op.cc
paddle/operators/add_op.cc
+3
-3
paddle/operators/add_op.cu
paddle/operators/add_op.cu
+2
-3
paddle/operators/add_op.h
paddle/operators/add_op.h
+6
-7
paddle/platform/device_context.cc
paddle/platform/device_context.cc
+5
-4
paddle/platform/device_context.h
paddle/platform/device_context.h
+5
-8
未找到文件。
paddle/framework/operator.cc
浏览文件 @
d649dbf4
...
...
@@ -18,13 +18,15 @@ namespace paddle {
namespace
framework
{
template
<
>
DeviceType
*
KernelContext
::
get_eigen_device
<
CPUPlace
>
()
{
return
device_context_
.
get_eigen_device
<
DeviceType
>
();
Eigen
::
DefaultDevice
*
OpKernel
::
KernelContext
::
get_eigen_device
<
platform
::
CPUPlace
,
Eigen
::
DefaultDevice
>
()
const
{
return
device_context_
.
get_eigen_device
<
Eigen
::
DefaultDevice
>
();
}
#ifndef PADDLE_ONLY_CPU
template
<
>
DeviceType
*
KernelContext
::
get_eigen_device
<
GPUPlace
>
()
{
DeviceType
*
OpKernel
::
KernelContext
::
get_eigen_device
<
platform
::
GPUPlace
>
()
const
{
return
device_context_
.
get_eigen_device
<
DeviceType
>
();
}
#endif
...
...
paddle/framework/operator.h
浏览文件 @
d649dbf4
...
...
@@ -33,13 +33,13 @@ template <typename T>
struct
EigenDeviceConverter
;
template
<
>
struct
EigenDeviceConverter
<
CPUPlace
>
{
struct
EigenDeviceConverter
<
platform
::
CPUPlace
>
{
using
EigenDeviceType
=
Eigen
::
DefaultDevice
;
};
#ifndef PADDLE_ONLY_CPU
template
<
>
struct
EigenDeviceConverter
<
GPUPlace
>
{
struct
EigenDeviceConverter
<
platform
::
GPUPlace
>
{
using
EigenDeviceType
=
Eigen
::
GpuDevice
;
};
#endif
...
...
@@ -87,13 +87,15 @@ class OperatorBase {
AttributeMap
attrs_
;
};
/**
class
OpKernel
{
public:
/**
* KernelContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class
KernelContext
{
class
KernelContext
{
public:
KernelContext
(
const
OperatorBase
*
op
,
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
platform
::
DeviceContext
&
device_context
)
...
...
@@ -107,19 +109,16 @@ class KernelContext {
return
scope_
->
GetVariable
(
op_
.
outputs_
[
index
]);
}
platform
::
DeviceContext
&
device_context
()
const
{
return
device_context_
;
}
template
<
typename
PlaceType
,
typename
DeviceType
=
EigenDeviceConverter
<
PlaceType
>
::
EigenDeviceType
>
DeviceType
*
get_eigen_device
();
template
<
typename
PlaceType
,
typename
DeviceType
=
typename
EigenDeviceConverter
<
PlaceType
>
::
EigenDeviceType
>
DeviceType
*
get_eigen_device
()
const
;
const
OperatorBase
&
op_
;
const
std
::
shared_ptr
<
Scope
>&
scope_
;
const
platform
::
DeviceContext
&
device_context_
;
};
};
class
OpKernel
{
public:
virtual
void
Compute
(
const
KernelContext
&
context
)
const
=
0
;
virtual
~
OpKernel
()
{}
...
...
paddle/framework/tensor.h
浏览文件 @
d649dbf4
...
...
@@ -35,7 +35,7 @@ class Tensor {
template
<
typename
T
>
const
T
*
data
()
const
{
T
*
data
()
const
{
PADDLE_ENFORCE
(
holder_
!=
nullptr
,
"Tenosr has not been initialized. Call Tensor::mutable_data first."
);
...
...
@@ -90,7 +90,7 @@ class Tensor {
// flat to rank = 1
template
<
typename
T
>
typename
TTypes
<
T
>::
Flat
flat
()
{
return
shaped
<
T
,
1
>
(
{
NumElements
()}
);
return
shaped
<
T
,
1
>
(
make_ddim
({
static_cast
<
int
>
(
NumElements
())})
);
}
// to TensorType Vec
...
...
@@ -114,7 +114,7 @@ class Tensor {
template
<
typename
T
>
typename
TTypes
<
T
>::
ConstFlat
flat
()
const
{
return
shaped
<
T
,
1
>
(
{
NumElements
()}
);
return
shaped
<
T
,
1
>
(
make_ddim
({
static_cast
<
int
>
(
NumElements
())})
);
}
template
<
typename
T
>
...
...
paddle/operators/add_op.cc
浏览文件 @
d649dbf4
...
...
@@ -40,6 +40,6 @@ The equation is: Out = X + Y
}
// namespace paddle
REGISTER_OP
(
add_two
,
paddle
::
operators
::
AddOp
,
paddle
::
operators
::
AddOpMaker
);
REGISTER_OP_CPU_KERNEL
(
add_two
,
::
paddle
::
operators
::
AddKernel
<::
paddle
::
platform
::
CPUPlace
,
float
>
);
\ No newline at end of file
typedef
paddle
::
operators
::
AddKernel
<::
paddle
::
platform
::
CPUPlace
,
float
>
AddKernel_CPU_float
;
REGISTER_OP_CPU_KERNEL
(
add_two
,
AddKernel_CPU_float
);
\ No newline at end of file
paddle/operators/add_op.cu
浏览文件 @
d649dbf4
#define EIGEN_USE_GPU
#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"
typedef
paddle
::
operators
::
AddKernel
<::
paddle
::
platform
::
GPUPlace
,
float
>
AddKernel_GPU_float
;
REGISTER_OP_GPU_KERNEL
(
add_two
,
paddle
::
operators
::
AddKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
\ No newline at end of file
AddKernel_GPU_float
);
\ No newline at end of file
paddle/operators/add_op.h
浏览文件 @
d649dbf4
...
...
@@ -6,19 +6,18 @@
namespace
paddle
{
namespace
operators
{
// Place can be CPUPlace or GPUPlace
template
<
typename
Place
,
typename
DataType
>
template
<
typename
Place
,
typename
T
>
class
AddKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
KernelContext
&
context
)
const
override
{
auto
*
input0
=
context
.
Input
(
0
);
auto
*
input1
=
context
.
Input
(
1
);
auto
input0
=
context
.
Input
(
0
)
->
Get
<
framework
::
Tensor
>
();
auto
input1
=
context
.
Input
(
1
)
->
Get
<
framework
::
Tensor
>
();
auto
*
output
=
context
.
Output
(
0
)
->
GetMutable
<
framework
::
Tensor
>
();
auto
*
output
=
context
.
Output
(
0
);
output
->
mutable_data
<
DataType
>
(
Place
());
output
->
mutable_data
<
T
>
(
Place
());
output
->
flat
<
T
>
().
device
(
*
(
context
.
get_eigen_device
<
Place
>
()))
=
input0
->
flat
<
T
>
()
+
input1
->
flat
<
T
>
();
input0
.
flat
<
T
>
()
+
input1
.
flat
<
T
>
();
}
};
...
...
paddle/platform/device_context.cc
浏览文件 @
d649dbf4
...
...
@@ -15,14 +15,15 @@ namespace paddle {
namespace
platform
{
template
<
>
Eigen
::
DefaultDevice
*
DeviceContext
::
get_eigen_device
<
Eigen
::
DefaultDevice
>
()
{
return
reinterpret_cast
<
CPUDeviceContext
*>
(
this
)
->
eigen_device
();
Eigen
::
DefaultDevice
*
DeviceContext
::
get_eigen_device
<
Eigen
::
DefaultDevice
>
()
const
{
return
reinterpret_cast
<
const
CPUDeviceContext
*>
(
this
)
->
eigen_device
();
}
#ifndef PADDLE_ONLY_CPU
template
<
>
Eigen
::
GpuDevice
*
DeviceContext
::
get_eigen_device
<
Eigen
::
GpuDevice
>
()
{
return
reinterpret_cast
<
CUDADeviceContext
*>
(
this
)
->
eigen_device
();
Eigen
::
GpuDevice
*
DeviceContext
::
get_eigen_device
<
Eigen
::
GpuDevice
>
()
const
{
return
reinterpret_cast
<
const
CUDADeviceContext
*>
(
this
)
->
eigen_device
();
}
#endif
...
...
paddle/platform/device_context.h
浏览文件 @
d649dbf4
...
...
@@ -32,17 +32,14 @@ class DeviceContext {
virtual
Place
GetPlace
()
const
=
0
;
template
<
typename
DeviceType
>
DeviceType
*
get_eigen_device
();
DeviceType
*
get_eigen_device
()
const
;
};
class
CPUDeviceContext
:
public
DeviceContext
{
public:
Eigen
::
DefaultDevice
*
eigen_device
()
{
if
(
!
eigen_device_
)
{
eigen_device_
.
reset
(
new
Eigen
::
DefaultDevice
());
}
return
eigen_device_
.
get
();
}
CPUDeviceContext
()
{
eigen_device_
.
reset
(
new
Eigen
::
DefaultDevice
());
}
Eigen
::
DefaultDevice
*
eigen_device
()
const
{
return
eigen_device_
.
get
();
}
Place
GetPlace
()
const
override
{
Place
retv
=
CPUPlace
();
...
...
@@ -91,7 +88,7 @@ class CUDADeviceContext : public DeviceContext {
cudaStream_t
stream
()
{
return
stream_
;
}
Eigen
::
GpuDevice
*
eigen_device
()
{
return
eigen_device_
.
get
();
}
Eigen
::
GpuDevice
*
eigen_device
()
const
{
return
eigen_device_
.
get
();
}
cublasHandle_t
cublas_handle
()
{
if
(
!
blas_handle_
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录