Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ccafd2e5
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ccafd2e5
编写于
4月 24, 2022
作者:
R
ronnywang
提交者:
GitHub
4月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[CustomDevice] add eager mode support (#42034)
上级
0e0f7da6
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
17 addition
and
5 deletion
+17
-5
paddle/fluid/pybind/eager.cc
paddle/fluid/pybind/eager.cc
+4
-1
paddle/fluid/pybind/eager_utils.cc
paddle/fluid/pybind/eager_utils.cc
+7
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+6
-3
未找到文件。
paddle/fluid/pybind/eager.cc
浏览文件 @
ccafd2e5
...
@@ -146,10 +146,13 @@ void InitTensorWithNumpyValue(TensorObject* self, const py::object& array,
...
@@ -146,10 +146,13 @@ void InitTensorWithNumpyValue(TensorObject* self, const py::object& array,
zero_copy
);
zero_copy
);
}
else
if
(
platform
::
is_npu_place
(
place
))
{
}
else
if
(
platform
::
is_npu_place
(
place
))
{
SetTensorFromPyArray
<
platform
::
NPUPlace
>
(
impl_ptr
,
array
,
place
,
zero_copy
);
SetTensorFromPyArray
<
platform
::
NPUPlace
>
(
impl_ptr
,
array
,
place
,
zero_copy
);
}
else
if
(
platform
::
is_custom_place
(
place
))
{
SetTensorFromPyArray
<
platform
::
CustomPlace
>
(
impl_ptr
,
array
,
place
,
zero_copy
);
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Place should be one of "
"Place should be one of "
"CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace"
));
"CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace
/CustomPlace
"
));
}
}
}
}
...
...
paddle/fluid/pybind/eager_utils.cc
浏览文件 @
ccafd2e5
...
@@ -46,6 +46,7 @@ extern PyTypeObject* g_cpuplace_pytype;
...
@@ -46,6 +46,7 @@ extern PyTypeObject* g_cpuplace_pytype;
extern
PyTypeObject
*
g_xpuplace_pytype
;
extern
PyTypeObject
*
g_xpuplace_pytype
;
extern
PyTypeObject
*
g_npuplace_pytype
;
extern
PyTypeObject
*
g_npuplace_pytype
;
extern
PyTypeObject
*
g_cudapinnedplace_pytype
;
extern
PyTypeObject
*
g_cudapinnedplace_pytype
;
extern
PyTypeObject
*
g_customplace_pytype
;
extern
PyTypeObject
*
g_framework_tensor_pytype
;
extern
PyTypeObject
*
g_framework_tensor_pytype
;
extern
PyTypeObject
*
g_framework_lodtensorarray_pytype
;
extern
PyTypeObject
*
g_framework_lodtensorarray_pytype
;
extern
PyTypeObject
*
g_custom_op_kernel_ctx_pytype
;
extern
PyTypeObject
*
g_custom_op_kernel_ctx_pytype
;
...
@@ -377,10 +378,15 @@ platform::Place CastPyArg2Place(PyObject* obj, ssize_t arg_pos) {
...
@@ -377,10 +378,15 @@ platform::Place CastPyArg2Place(PyObject* obj, ssize_t arg_pos) {
}
else
if
(
PyObject_IsInstance
(
}
else
if
(
PyObject_IsInstance
(
obj
,
reinterpret_cast
<
PyObject
*>
(
g_cudapinnedplace_pytype
)))
{
obj
,
reinterpret_cast
<
PyObject
*>
(
g_cudapinnedplace_pytype
)))
{
place
=
::
pybind11
::
handle
(
obj
).
cast
<
platform
::
CUDAPinnedPlace
>
();
place
=
::
pybind11
::
handle
(
obj
).
cast
<
platform
::
CUDAPinnedPlace
>
();
}
else
if
(
PyObject_IsInstance
(
obj
,
reinterpret_cast
<
PyObject
*>
(
g_customplace_pytype
)))
{
place
=
::
pybind11
::
handle
(
obj
).
cast
<
platform
::
CustomPlace
>
();
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"argument (position %d) must be "
"argument (position %d) must be "
"one of(Place,CUDAPlace,CPUPlace,XPUPlace,NPUPlace,CUDAPinnedPlace), "
"one "
"of(Place,CUDAPlace,CPUPlace,XPUPlace,NPUPlace,CUDAPinnedPlace,"
"CustomPlace), "
"but got %s"
,
"but got %s"
,
arg_pos
+
1
,
reinterpret_cast
<
PyTypeObject
*>
(
obj
->
ob_type
)
->
tp_name
));
arg_pos
+
1
,
reinterpret_cast
<
PyTypeObject
*>
(
obj
->
ob_type
)
->
tp_name
));
}
}
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
ccafd2e5
...
@@ -193,6 +193,7 @@ PyTypeObject *g_xpuplace_pytype = nullptr;
...
@@ -193,6 +193,7 @@ PyTypeObject *g_xpuplace_pytype = nullptr;
PyTypeObject
*
g_npuplace_pytype
=
nullptr
;
PyTypeObject
*
g_npuplace_pytype
=
nullptr
;
PyTypeObject
*
g_cudapinnedplace_pytype
=
nullptr
;
PyTypeObject
*
g_cudapinnedplace_pytype
=
nullptr
;
PyTypeObject
*
g_mluplace_pytype
=
nullptr
;
PyTypeObject
*
g_mluplace_pytype
=
nullptr
;
PyTypeObject
*
g_customplace_pytype
=
nullptr
;
PyTypeObject
*
g_framework_tensor_pytype
=
nullptr
;
PyTypeObject
*
g_framework_tensor_pytype
=
nullptr
;
PyTypeObject
*
g_framework_lodtensorarray_pytype
=
nullptr
;
PyTypeObject
*
g_framework_lodtensorarray_pytype
=
nullptr
;
PyTypeObject
*
g_custom_op_kernel_ctx_pytype
=
nullptr
;
PyTypeObject
*
g_custom_op_kernel_ctx_pytype
=
nullptr
;
...
@@ -2125,8 +2126,8 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -2125,8 +2126,8 @@ All parameter, weight, gradient are variables in Paddle.
#endif
#endif
return
devices
;
return
devices
;
});
});
py
::
class_
<
platform
::
CustomPlace
>
(
m
,
"CustomPlace"
,
py
::
class_
<
platform
::
CustomPlace
>
customplace
(
m
,
"CustomPlace"
,
R"DOC(
R"DOC(
CustomPlace is a descriptor of a device.
CustomPlace is a descriptor of a device.
It represents a custom device on which a tensor will be allocated and a model will run.
It represents a custom device on which a tensor will be allocated and a model will run.
...
@@ -2135,7 +2136,9 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -2135,7 +2136,9 @@ All parameter, weight, gradient are variables in Paddle.
import paddle
import paddle
fake_cpu_place = paddle.CustomPlace("FakeCPU", 0)
fake_cpu_place = paddle.CustomPlace("FakeCPU", 0)
)DOC"
)
)DOC"
);
g_customplace_pytype
=
reinterpret_cast
<
PyTypeObject
*>
(
customplace
.
ptr
());
customplace
.
def
(
"__init__"
,
.
def
(
"__init__"
,
[](
platform
::
CustomPlace
&
self
,
const
std
::
string
&
device_type
,
[](
platform
::
CustomPlace
&
self
,
const
std
::
string
&
device_type
,
int
dev_id
)
{
int
dev_id
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录