Commit 1d8fe2a2 authored by Yu Yang, committed by dzhwinter

Enhance device context pool (#9293)

Parent 64c5c8f8
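This commit reworks DeviceContextPool in four ways, all visible in the hunks below: Get() returns a mutable DeviceContext* instead of a const pointer; the pool owns its contexts through std::unique_ptr rather than raw pointers; the constructor deduplicates the requested places through an unordered_set before building contexts; and CUDADeviceContext::Wait() is serialized with a new mutable mutex. A minimal usage sketch follows (hypothetical driver code, not part of the commit; it assumes the constructor and Get() are publicly accessible as declared in the hunks):

    #include <vector>
    #include "paddle/fluid/platform/device_context.h"
    #include "paddle/fluid/platform/place.h"

    int main() {
      namespace plat = paddle::platform;
      // Deliberately repeat a place: after this commit the pool
      // builds only one context per unique place.
      std::vector<plat::Place> places = {plat::CPUPlace(), plat::CPUPlace()};
      plat::DeviceContextPool pool(places);
      // Get() now returns DeviceContext*, not const DeviceContext*.
      plat::DeviceContext* ctx = pool.Get(plat::CPUPlace());
      ctx->Wait();  // for CUDA contexts, now guarded by mutex_
      return 0;
    }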
@@ -10,43 +10,45 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/platform/device_context.h"
+#include <unordered_set>
 #include "paddle/fluid/memory/memory.h"
 
 namespace paddle {
 namespace platform {
 
 DeviceContextPool* DeviceContextPool::pool = nullptr;
 
-const platform::DeviceContext* DeviceContextPool::Get(
-    const platform::Place& place) {
+platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
   auto it = device_contexts_.find(place);
   if (it == device_contexts_.end()) {
     PADDLE_THROW(
         "'Place' is not supported, Please re-compile with WITH_GPU "
         "option");
   }
-  return it->second;
+  return it->second.get();
 }
 
 DeviceContextPool::DeviceContextPool(
     const std::vector<platform::Place>& places) {
   PADDLE_ENFORCE_GT(places.size(), 0);
-  for (size_t i = 0; i < places.size(); i++) {
-    if (platform::is_cpu_place(places[i])) {
+  using PtrType = std::unique_ptr<DeviceContext>;
+  std::unordered_set<Place, PlaceHash> set;
+  for (auto& p : places) {
+    set.insert(p);
+  }
+
+  for (auto& p : set) {
+    if (platform::is_cpu_place(p)) {
 #ifdef PADDLE_WITH_MKLDNN
-      device_contexts_.emplace(places[i],
-                               new platform::MKLDNNDeviceContext(
-                                   boost::get<platform::CPUPlace>(places[i])));
+      device_contexts_.emplace(
+          p, PtrType(new MKLDNNDeviceContext(boost::get<CPUPlace>(p))));
 #else
-      device_contexts_.emplace(places[i],
-                               new platform::CPUDeviceContext(
-                                   boost::get<platform::CPUPlace>(places[i])));
+      device_contexts_.emplace(
+          p, PtrType(new CPUDeviceContext(boost::get<CPUPlace>(p))));
 #endif
-    } else if (platform::is_gpu_place(places[i])) {
+    } else if (platform::is_gpu_place(p)) {
 #ifdef PADDLE_WITH_CUDA
-      device_contexts_.emplace(places[i],
-                               new platform::CUDADeviceContext(
-                                   boost::get<platform::CUDAPlace>(places[i])));
+      device_contexts_.emplace(
+          p, PtrType(new CUDADeviceContext(boost::get<CUDAPlace>(p))));
 #else
       PADDLE_THROW(
           "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
@@ -159,6 +161,7 @@ CUDADeviceContext::~CUDADeviceContext() {
 Place CUDADeviceContext::GetPlace() const { return place_; }
 
 void CUDADeviceContext::Wait() const {
+  std::lock_guard<std::mutex> guard(mutex_);
   PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
   PADDLE_ENFORCE(cudaGetLastError());
 }
......
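The constructor hunk above replaces index-based iteration over a possibly repeated place list with a dedupe-then-construct pattern. Here is a standalone sketch of that pattern with simplified, hypothetical types (Key and Context stand in for Place and DeviceContext; Paddle's Place already has equality defined, so only the hash functor needs supplying there):

    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <memory>
    #include <unordered_map>
    #include <unordered_set>
    #include <vector>

    struct Key { int type; int dev; };

    struct KeyHash {
      std::size_t operator()(const Key& k) const {
        // Same packing idea as PlaceHash: device id above, type bits below.
        return std::hash<int>()(k.dev << 4 | k.type);
      }
    };

    struct KeyEq {
      bool operator()(const Key& a, const Key& b) const {
        return a.type == b.type && a.dev == b.dev;
      }
    };

    struct Context { explicit Context(Key k) : key(k) {} Key key; };

    int main() {
      std::vector<Key> places = {{0, 0}, {1, 0}, {0, 0}};  // one duplicate
      // Deduplicate first, so each unique key gets exactly one Context.
      std::unordered_set<Key, KeyHash, KeyEq> uniq(places.begin(), places.end());
      std::unordered_map<Key, std::unique_ptr<Context>, KeyHash, KeyEq> pool;
      for (const auto& p : uniq) {
        pool.emplace(p, std::unique_ptr<Context>(new Context(p)));
      }
      std::cout << pool.size() << "\n";  // prints 2, not 3
      return 0;
    }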
@@ -103,6 +103,7 @@ class CUDADeviceContext : public DeviceContext {
   std::unique_ptr<Eigen::GpuDevice> eigen_device_;
   std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
 
+  mutable std::mutex mutex_;
   cudaStream_t stream_;
   cudnnHandle_t cudnn_handle_;
   cublasHandle_t cublas_handle_;
@@ -159,7 +160,7 @@ class DeviceContextPool {
   }
 
   /*! \brief Return handle of single device context. */
-  const platform::DeviceContext* Get(const platform::Place& place);
+  platform::DeviceContext* Get(const platform::Place& place);
 
   template <typename Place>
   const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
@@ -172,19 +173,8 @@ class DeviceContextPool {
  private:
   static DeviceContextPool* pool;
-  constexpr static int LEFT_SHIFT = 8;
-  struct Hash {
-    std::hash<int> hash_;
-    size_t operator()(const platform::Place& place) const {
-      int pre_hash = place.which() << LEFT_SHIFT;
-      if (platform::is_gpu_place(place)) {
-        pre_hash += boost::get<platform::CUDAPlace>(place).GetDeviceId();
-      }
-      return hash_(pre_hash);
-    }
-  };
-  std::unordered_map<const platform::Place, const platform::DeviceContext*,
-                     Hash>
+  std::unordered_map<const platform::Place,
+                     std::unique_ptr<platform::DeviceContext>, PlaceHash>
       device_contexts_;
   DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
 };
......
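The header hunk above pairs the new locking in Wait() with `mutable std::mutex mutex_`: Wait() is a const member function, and only a mutable member can be locked through a const path. A minimal sketch of this idiom (hypothetical class, not Paddle code):

    #include <mutex>

    class Synchronizable {
     public:
      void Wait() const {
        // Legal only because mutex_ is declared mutable.
        std::lock_guard<std::mutex> guard(mutex_);
        // ... synchronize the underlying device stream here ...
      }

     private:
      mutable std::mutex mutex_;  // lockable even from const methods
    };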
@@ -65,6 +65,18 @@ bool is_cpu_place(const Place &);
 bool places_are_same_class(const Place &, const Place &);
 bool is_same_place(const Place &, const Place &);
 
+struct PlaceHash {
+  std::size_t operator()(const Place &p) const {
+    constexpr size_t num_dev_bits = 4;
+    std::hash<int> ihash;
+    size_t dev_id = 0;
+    if (is_gpu_place(p)) {
+      dev_id = boost::get<CUDAPlace>(p).device;
+    }
+    return ihash(dev_id << num_dev_bits | p.which());
+  }
+};
+
 std::ostream &operator<<(std::ostream &, const Place &);
 
 template <typename Visitor>
......
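PlaceHash packs the device id and the boost::variant type index into a single integer before hashing: p.which() occupies the low num_dev_bits (4) bits and the device id sits above them, so distinct places map to distinct packed keys as long as the variant holds fewer than 16 alternatives. A small worked check of the packing arithmetic (the concrete which() values 0 and 1 are assumptions for illustration):

    #include <cassert>
    #include <cstddef>

    // The value PlaceHash feeds to std::hash<int>.
    std::size_t pack(std::size_t dev_id, int which) {
      constexpr std::size_t num_dev_bits = 4;
      return dev_id << num_dev_bits | which;
    }

    int main() {
      assert(pack(0, 0) != pack(0, 1));  // different place types differ
      assert(pack(2, 1) != pack(3, 1));  // different device ids differ
      assert(pack(2, 1) == pack(2, 1));  // equal places agree
      // Caveat: with num_dev_bits = 4 the type index must stay
      // below 16, or it bleeds into the device-id bits.
      return 0;
    }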