未验证 提交 2d0e5592 编写于 作者: Y yuyang18

Use std::map for Place <--> DeviceContext

上级 f7fd711e
...@@ -124,16 +124,10 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) { ...@@ -124,16 +124,10 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (!events_.empty()) { // Use event if (!events_.empty()) { // Use event
std::function<void()> method = callback; std::function<void()> method = callback;
// NOTE(zcd): device context must be ordered here because RecordEvent
// will use a mutex to ensure the safe of multi-threads.
std::map<platform::DeviceContext *, platform::Place> ordered_ctxes;
for (auto &p : dev_ctxes_) { for (auto &p : dev_ctxes_) {
ordered_ctxes.emplace(p.second, p.first);
}
for (auto &p : ordered_ctxes) {
method = [method, p, this]() { method = [method, p, this]() {
static_cast<platform::CUDADeviceContext *>(p.first)->RecordEvent( static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent(
events_.at(boost::get<platform::CUDAPlace>(p.second).device), events_.at(boost::get<platform::CUDAPlace>(p.first).device),
method); method);
}; };
} }
......
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <map>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
...@@ -92,9 +92,7 @@ class OpHandleBase { ...@@ -92,9 +92,7 @@ class OpHandleBase {
std::vector<VarHandleBase *> inputs_; std::vector<VarHandleBase *> inputs_;
std::vector<VarHandleBase *> outputs_; std::vector<VarHandleBase *> outputs_;
std::unordered_map<platform::Place, platform::DeviceContext *, std::map<platform::Place, platform::DeviceContext *> dev_ctxes_;
platform::PlaceHash>
dev_ctxes_;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
std::unordered_map<int, cudaEvent_t> events_; std::unordered_map<int, cudaEvent_t> events_;
......
...@@ -54,8 +54,7 @@ struct ReduceLoDTensor { ...@@ -54,8 +54,7 @@ struct ReduceLoDTensor {
inline void GatherSelectedRows( inline void GatherSelectedRows(
const std::vector<const SelectedRows *> &src_selecte_rows_, const std::vector<const SelectedRows *> &src_selecte_rows_,
const std::vector<platform::Place> &in_places, const std::vector<platform::Place> &in_places,
const std::unordered_map<platform::Place, platform::DeviceContext *, const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
platform::PlaceHash> &dev_ctxes,
const platform::Place &out_place, SelectedRows *dst_selecte_rows) { const platform::Place &out_place, SelectedRows *dst_selecte_rows) {
PADDLE_ENFORCE(!src_selecte_rows_.empty()); PADDLE_ENFORCE(!src_selecte_rows_.empty());
......
...@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and ...@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include <set>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
...@@ -35,7 +36,7 @@ DeviceContextPool::DeviceContextPool( ...@@ -35,7 +36,7 @@ DeviceContextPool::DeviceContextPool(
const std::vector<platform::Place>& places) { const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0); PADDLE_ENFORCE_GT(places.size(), 0);
using PtrType = std::unique_ptr<DeviceContext>; using PtrType = std::unique_ptr<DeviceContext>;
std::unordered_set<Place, PlaceHash> set; std::set<Place> set;
for (auto& p : places) { for (auto& p : places) {
set.insert(p); set.insert(p);
} }
......
...@@ -27,12 +27,12 @@ limitations under the License. */ ...@@ -27,12 +27,12 @@ limitations under the License. */
#include <mkldnn.hpp> #include <mkldnn.hpp>
#endif #endif
#include <map>
#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
#include "glog/logging.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -201,9 +201,7 @@ class DeviceContextPool { ...@@ -201,9 +201,7 @@ class DeviceContextPool {
private: private:
static DeviceContextPool* pool; static DeviceContextPool* pool;
std::unordered_map<const platform::Place, std::map<Place, std::unique_ptr<DeviceContext>> device_contexts_;
std::unique_ptr<platform::DeviceContext>, PlaceHash>
device_contexts_;
DISABLE_COPY_AND_ASSIGN(DeviceContextPool); DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
}; };
......
...@@ -30,6 +30,7 @@ struct CPUPlace { ...@@ -30,6 +30,7 @@ struct CPUPlace {
// needed for variant equality comparison // needed for variant equality comparison
inline bool operator==(const CPUPlace &) const { return true; } inline bool operator==(const CPUPlace &) const { return true; }
inline bool operator!=(const CPUPlace &) const { return false; } inline bool operator!=(const CPUPlace &) const { return false; }
inline bool operator<(const CPUPlace &) const { return false; }
}; };
struct CUDAPlace { struct CUDAPlace {
...@@ -42,6 +43,7 @@ struct CUDAPlace { ...@@ -42,6 +43,7 @@ struct CUDAPlace {
return device == o.device; return device == o.device;
} }
inline bool operator!=(const CUDAPlace &o) const { return !(*this == o); } inline bool operator!=(const CUDAPlace &o) const { return !(*this == o); }
inline bool operator<(const CUDAPlace &o) const { return device < o.device; }
int device; int device;
}; };
...@@ -52,6 +54,7 @@ struct CUDAPinnedPlace { ...@@ -52,6 +54,7 @@ struct CUDAPinnedPlace {
// needed for variant equality comparison // needed for variant equality comparison
inline bool operator==(const CUDAPinnedPlace &) const { return true; } inline bool operator==(const CUDAPinnedPlace &) const { return true; }
inline bool operator!=(const CUDAPinnedPlace &) const { return false; } inline bool operator!=(const CUDAPinnedPlace &) const { return false; }
inline bool operator<(const CUDAPinnedPlace &) const { return false; }
}; };
struct IsCUDAPlace : public boost::static_visitor<bool> { struct IsCUDAPlace : public boost::static_visitor<bool> {
...@@ -89,18 +92,6 @@ bool is_cuda_pinned_place(const Place &); ...@@ -89,18 +92,6 @@ bool is_cuda_pinned_place(const Place &);
bool places_are_same_class(const Place &, const Place &); bool places_are_same_class(const Place &, const Place &);
bool is_same_place(const Place &, const Place &); bool is_same_place(const Place &, const Place &);
struct PlaceHash {
std::size_t operator()(const Place &p) const {
constexpr size_t num_dev_bits = 4;
std::hash<int> ihash;
size_t dev_id = 0;
if (is_gpu_place(p)) {
dev_id = boost::get<CUDAPlace>(p).device;
}
return ihash(dev_id << num_dev_bits | p.which());
}
};
std::ostream &operator<<(std::ostream &, const Place &); std::ostream &operator<<(std::ostream &, const Place &);
template <typename Visitor> template <typename Visitor>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册