未验证 提交 c4203751 编写于 作者: Z Zeng Jinle 提交者: GitHub

Merge pull request #15806 from sneaxiy/fix-compiler

Fix compiler.py place compare bug
......@@ -106,6 +106,11 @@ bool IsCompiledWithDIST() {
#endif
}
template <typename PlaceType1, typename PlaceType2>
static inline bool IsSamePlace(const PlaceType1 &p1, const PlaceType2 &p2) {
return paddle::platform::Place(p1) == paddle::platform::Place(p2);
}
PYBIND11_MODULE(core, m) {
// Not used, just make sure cpu_info.cc is linked.
paddle::platform::CpuTotalPhysicalMemory();
......@@ -732,23 +737,45 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_THROW("Cannot use CUDAPlace in CPU only version");
#endif
})
.def("_equals", &IsSamePlace<platform::CUDAPlace, platform::Place>)
.def("_equals", &IsSamePlace<platform::CUDAPlace, platform::CUDAPlace>)
.def("_equals", &IsSamePlace<platform::CUDAPlace, platform::CPUPlace>)
.def("_equals",
&IsSamePlace<platform::CUDAPlace, platform::CUDAPinnedPlace>)
.def("__str__", string::to_string<const platform::CUDAPlace &>);
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace")
.def(py::init<>())
.def("_equals", &IsSamePlace<platform::CPUPlace, platform::Place>)
.def("_equals", &IsSamePlace<platform::CPUPlace, platform::CUDAPlace>)
.def("_equals", &IsSamePlace<platform::CPUPlace, platform::CPUPlace>)
.def("_equals",
&IsSamePlace<platform::CPUPlace, platform::CUDAPinnedPlace>)
.def("__str__", string::to_string<const platform::CPUPlace &>);
py::class_<paddle::platform::CUDAPinnedPlace>(m, "CUDAPinnedPlace")
.def("__init__",
[](platform::CUDAPinnedPlace &) {
[](platform::CUDAPinnedPlace &self) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("Cannot use CUDAPinnedPlace in CPU only version");
#endif
new (&self) platform::CUDAPinnedPlace();
})
.def("_equals", &IsSamePlace<platform::CUDAPinnedPlace, platform::Place>)
.def("_equals",
&IsSamePlace<platform::CUDAPinnedPlace, platform::CUDAPlace>)
.def("_equals",
&IsSamePlace<platform::CUDAPinnedPlace, platform::CPUPlace>)
.def("_equals",
&IsSamePlace<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>)
.def("__str__", string::to_string<const platform::CUDAPinnedPlace &>);
py::class_<platform::Place>(m, "Place")
.def(py::init<>())
.def("_equals", &IsSamePlace<platform::Place, platform::Place>)
.def("_equals", &IsSamePlace<platform::Place, platform::CUDAPlace>)
.def("_equals", &IsSamePlace<platform::Place, platform::CPUPlace>)
.def("_equals", &IsSamePlace<platform::Place, platform::CUDAPinnedPlace>)
.def("is_gpu_place",
[](platform::Place &self) { return platform::is_gpu_place(self); })
.def("gpu_device_id",
......
......@@ -220,7 +220,7 @@ class CompiledProgram(object):
if self._compiled:
if scope and self._scope != scope:
raise ValueError("Cannot compile with different scope")
if place and self._place != place:
if place and not self._place._equals(place):
raise ValueError("Cannot compile with different place")
return self
self._compiled = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册