Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
4ecf68e0
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
4ecf68e0
编写于
7月 25, 2017
作者:
Q
qijun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug in register gpu OpKernel
上级
358261f0
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
14 addition
and
9 deletion
+14
-9
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+4
-3
paddle/framework/operator.h
paddle/framework/operator.h
+5
-1
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+3
-1
paddle/pybind/tensor_bind.h
paddle/pybind/tensor_bind.h
+2
-4
未找到文件。
paddle/framework/op_registry.h
浏览文件 @
4ecf68e0
...
...
@@ -403,15 +403,16 @@ class GradOpRegisterHelper {
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##type##_##DEVICE_TYPE##__, \
"REGISTER_OP_KERNEL must be in global namespace"); \
struct __op_kernel_register__##type##__
{
\
__op_kernel_register__##type##__
() {
\
struct __op_kernel_register__##type##__
##DEVICE_TYPE##__ {
\
__op_kernel_register__##type##__
##DEVICE_TYPE##__() {
\
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new __VA_ARGS__()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__; \
static __op_kernel_register__##type##__##DEVICE_TYPE##__ \
__reg_kernel_##type##__##DEVICE_TYPE##__; \
int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; }
// (type, KernelType)
...
...
paddle/framework/operator.h
浏览文件 @
4ecf68e0
...
...
@@ -199,7 +199,11 @@ class OperatorWithKernel : public OperatorBase {
place_
=
dev_ctx
.
GetPlace
();
}
bool
operator
==
(
const
OpKernelKey
&
o
)
const
{
return
place_
==
o
.
place_
;
}
// bool operator==(const OpKernelKey& o) const { return place_ == o.place_;
// }
bool
operator
==
(
const
OpKernelKey
&
o
)
const
{
return
platform
::
places_are_same_class
(
place_
,
o
.
place_
);
}
};
struct
OpKernelHash
{
...
...
paddle/pybind/pybind.cc
浏览文件 @
4ecf68e0
...
...
@@ -80,9 +80,11 @@ PYBIND11_PLUGIN(core) {
self
.
mutable_data
<
int
>
(
place
);
})
.
def
(
"set"
,
paddle
::
pybind
::
PyCPUTensorSetFromArray
<
float
>
)
.
def
(
"set"
,
paddle
::
pybind
::
PyCUDATensorSetFromArray
<
float
>
)
.
def
(
"set"
,
paddle
::
pybind
::
PyCPUTensorSetFromArray
<
int
>
)
#ifndef PADDLE_ONLY_CPU
.
def
(
"set"
,
paddle
::
pybind
::
PyCUDATensorSetFromArray
<
float
>
)
.
def
(
"set"
,
paddle
::
pybind
::
PyCUDATensorSetFromArray
<
int
>
)
#endif
.
def
(
"shape"
,
[](
pd
::
Tensor
&
self
)
{
return
pd
::
vectorize
(
self
.
dims
());
});
...
...
paddle/pybind/tensor_bind.h
浏览文件 @
4ecf68e0
...
...
@@ -42,9 +42,6 @@ template <size_t I, typename... ARGS>
struct
CastToPyBufferImpl
<
true
,
I
,
ARGS
...
>
{
using
CUR_TYPE
=
typename
std
::
tuple_element
<
I
,
std
::
tuple
<
ARGS
...
>>::
type
;
py
::
buffer_info
operator
()(
framework
::
Tensor
&
tensor
)
{
PADDLE_ENFORCE
(
paddle
::
platform
::
is_cpu_place
(
tensor
.
holder_
->
place
()),
"Only CPU tensor can cast to numpy array"
);
if
(
std
::
type_index
(
typeid
(
CUR_TYPE
))
==
tensor
.
holder_
->
type
())
{
auto
dim_vec
=
framework
::
vectorize
(
tensor
.
dims
());
std
::
vector
<
size_t
>
dims_outside
;
...
...
@@ -99,6 +96,7 @@ void PyCPUTensorSetFromArray(
std
::
memcpy
(
dst
,
array
.
data
(),
sizeof
(
T
)
*
array
.
size
());
}
#ifndef PADDLE_ONLY_CPU
template
<
typename
T
>
void
PyCUDATensorSetFromArray
(
framework
::
Tensor
&
self
,
...
...
@@ -112,10 +110,10 @@ void PyCUDATensorSetFromArray(
self
.
Resize
(
framework
::
make_ddim
(
dims
));
auto
*
dst
=
self
.
mutable_data
<
T
>
(
place
);
std
::
memcpy
(
dst
,
array
.
data
(),
sizeof
(
T
)
*
array
.
size
());
paddle
::
platform
::
GpuMemcpySync
(
dst
,
array
.
data
(),
sizeof
(
T
)
*
array
.
size
(),
cudaMemcpyHostToDevice
);
}
#endif
}
// namespace pybind
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录