未验证 提交 c4203751 编写于 作者: Z Zeng Jinle 提交者: GitHub

Merge pull request #15806 from sneaxiy/fix-compiler

Fix compiler.py place compare bug
...@@ -106,6 +106,11 @@ bool IsCompiledWithDIST() { ...@@ -106,6 +106,11 @@ bool IsCompiledWithDIST() {
#endif #endif
} }
template <typename PlaceType1, typename PlaceType2>
static inline bool IsSamePlace(const PlaceType1 &p1, const PlaceType2 &p2) {
return paddle::platform::Place(p1) == paddle::platform::Place(p2);
}
PYBIND11_MODULE(core, m) { PYBIND11_MODULE(core, m) {
// Not used, just make sure cpu_info.cc is linked. // Not used, just make sure cpu_info.cc is linked.
paddle::platform::CpuTotalPhysicalMemory(); paddle::platform::CpuTotalPhysicalMemory();
...@@ -732,23 +737,45 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -732,23 +737,45 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_THROW("Cannot use CUDAPlace in CPU only version"); PADDLE_THROW("Cannot use CUDAPlace in CPU only version");
#endif #endif
}) })
.def("_equals", &IsSamePlace<platform::CUDAPlace, platform::Place>)
.def("_equals", &IsSamePlace<platform::CUDAPlace, platform::CUDAPlace>)
.def("_equals", &IsSamePlace<platform::CUDAPlace, platform::CPUPlace>)
.def("_equals",
&IsSamePlace<platform::CUDAPlace, platform::CUDAPinnedPlace>)
.def("__str__", string::to_string<const platform::CUDAPlace &>); .def("__str__", string::to_string<const platform::CUDAPlace &>);
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace") py::class_<paddle::platform::CPUPlace>(m, "CPUPlace")
.def(py::init<>()) .def(py::init<>())
.def("_equals", &IsSamePlace<platform::CPUPlace, platform::Place>)
.def("_equals", &IsSamePlace<platform::CPUPlace, platform::CUDAPlace>)
.def("_equals", &IsSamePlace<platform::CPUPlace, platform::CPUPlace>)
.def("_equals",
&IsSamePlace<platform::CPUPlace, platform::CUDAPinnedPlace>)
.def("__str__", string::to_string<const platform::CPUPlace &>); .def("__str__", string::to_string<const platform::CPUPlace &>);
py::class_<paddle::platform::CUDAPinnedPlace>(m, "CUDAPinnedPlace") py::class_<paddle::platform::CUDAPinnedPlace>(m, "CUDAPinnedPlace")
.def("__init__", .def("__init__",
[](platform::CUDAPinnedPlace &) { [](platform::CUDAPinnedPlace &self) {
#ifndef PADDLE_WITH_CUDA #ifndef PADDLE_WITH_CUDA
PADDLE_THROW("Cannot use CUDAPinnedPlace in CPU only version"); PADDLE_THROW("Cannot use CUDAPinnedPlace in CPU only version");
#endif #endif
new (&self) platform::CUDAPinnedPlace();
}) })
.def("_equals", &IsSamePlace<platform::CUDAPinnedPlace, platform::Place>)
.def("_equals",
&IsSamePlace<platform::CUDAPinnedPlace, platform::CUDAPlace>)
.def("_equals",
&IsSamePlace<platform::CUDAPinnedPlace, platform::CPUPlace>)
.def("_equals",
&IsSamePlace<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>)
.def("__str__", string::to_string<const platform::CUDAPinnedPlace &>); .def("__str__", string::to_string<const platform::CUDAPinnedPlace &>);
py::class_<platform::Place>(m, "Place") py::class_<platform::Place>(m, "Place")
.def(py::init<>()) .def(py::init<>())
.def("_equals", &IsSamePlace<platform::Place, platform::Place>)
.def("_equals", &IsSamePlace<platform::Place, platform::CUDAPlace>)
.def("_equals", &IsSamePlace<platform::Place, platform::CPUPlace>)
.def("_equals", &IsSamePlace<platform::Place, platform::CUDAPinnedPlace>)
.def("is_gpu_place", .def("is_gpu_place",
[](platform::Place &self) { return platform::is_gpu_place(self); }) [](platform::Place &self) { return platform::is_gpu_place(self); })
.def("gpu_device_id", .def("gpu_device_id",
......
...@@ -220,7 +220,7 @@ class CompiledProgram(object): ...@@ -220,7 +220,7 @@ class CompiledProgram(object):
if self._compiled: if self._compiled:
if scope and self._scope != scope: if scope and self._scope != scope:
raise ValueError("Cannot compile with different scope") raise ValueError("Cannot compile with different scope")
if place and self._place != place: if place and not self._place._equals(place):
raise ValueError("Cannot compile with different place") raise ValueError("Cannot compile with different place")
return self return self
self._compiled = True self._compiled = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册