From d331e97af85f4ef188edf52535bb04d0ecf26138 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 20 Feb 2019 11:08:38 +0800 Subject: [PATCH] fix compiler place compare test=develop --- paddle/fluid/pybind/pybind.cc | 29 ++++++++++++++++++++++++++++- python/paddle/fluid/compiler.py | 2 +- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index c50c38160..d8e57a1ac 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -106,6 +106,11 @@ bool IsCompiledWithDIST() { #endif } +template +static inline bool IsSamePlace(const PlaceType1 &p1, const PlaceType2 &p2) { + return paddle::platform::Place(p1) == paddle::platform::Place(p2); +} + PYBIND11_MODULE(core, m) { // Not used, just make sure cpu_info.cc is linked. paddle::platform::CpuTotalPhysicalMemory(); @@ -732,23 +737,45 @@ All parameter, weight, gradient are variables in Paddle. PADDLE_THROW("Cannot use CUDAPlace in CPU only version"); #endif }) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", + &IsSamePlace) .def("__str__", string::to_string); py::class_(m, "CPUPlace") .def(py::init<>()) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", + &IsSamePlace) .def("__str__", string::to_string); py::class_(m, "CUDAPinnedPlace") .def("__init__", - [](platform::CUDAPinnedPlace &) { + [](platform::CUDAPinnedPlace &self) { #ifndef PADDLE_WITH_CUDA PADDLE_THROW("Cannot use CUDAPinnedPlace in CPU only version"); #endif + new (&self) platform::CUDAPinnedPlace(); }) + .def("_equals", &IsSamePlace) + .def("_equals", + &IsSamePlace) + .def("_equals", + &IsSamePlace) + .def("_equals", + &IsSamePlace) .def("__str__", string::to_string); py::class_(m, "Place") .def(py::init<>()) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) + .def("_equals", &IsSamePlace) .def("is_gpu_place", [](platform::Place &self) { return platform::is_gpu_place(self); }) .def("gpu_device_id", diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py index b24cec044..0fecff81c 100644 --- a/python/paddle/fluid/compiler.py +++ b/python/paddle/fluid/compiler.py @@ -220,7 +220,7 @@ class CompiledProgram(object): if self._compiled: if scope and self._scope != scope: raise ValueError("Cannot compile with different scope") - if place and self._place != place: + if place and not self._place._equals(place): raise ValueError("Cannot compile with different place") return self self._compiled = True -- GitLab