Commit 5c333e41 authored by Yu Yang

Add dtor for dev_ctx

Parent 15f5f10e
paddle/fluid/framework/parallel_executor.cc
@@ -35,18 +35,18 @@ using details::VarHandleBase;
 class ParallelExecutorPrivate {
  public:
-  explicit ParallelExecutorPrivate(size_t num_threads)
-      : pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {}
+  explicit ParallelExecutorPrivate(size_t num_threads,
+                                   const std::vector<platform::Place> &places)
+      : places_(places),
+        fetch_dev_ctxs_(places),
+        pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {}

   std::vector<platform::Place> places_;
+  platform::DeviceContextPool fetch_dev_ctxs_;
   std::vector<Scope *> local_scopes_;
   Scope *global_scope_;
   std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
-  std::unordered_map<platform::Place, platform::DeviceContext *,
-                     platform::PlaceHash>
-      fetch_dev_ctxs_;

   platform::Place main_place_;
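The map of raw `DeviceContext *` that this hunk deletes was filled with `new`-ed contexts that nothing ever freed; holding a `DeviceContextPool` by value instead ties context lifetime to the executor. A minimal, self-contained sketch of that ownership pattern, using simplified stand-in types rather than the real Paddle classes:

#include <cstddef>
#include <cstdio>
#include <memory>
#include <vector>

// Stand-in for DeviceContext; its destructor is what the commit title adds.
struct Ctx {
  ~Ctx() { std::puts("ctx released"); }
};

// Stand-in for DeviceContextPool: owns one context per place.
struct Pool {
  explicit Pool(std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) ctxs_.emplace_back(new Ctx);
  }
  Ctx* Get(std::size_t i) { return ctxs_[i].get(); }  // borrowed, non-owning
  std::vector<std::unique_ptr<Ctx>> ctxs_;
};

struct ExecutorPrivate {
  Pool fetch_dev_ctxs{2};  // value member: destroyed with the executor
};

int main() {
  ExecutorPrivate p;
  (void)p.fetch_dev_ctxs.Get(0);
  return 0;
}  // prints "ctx released" twice -- no manual delete loop needed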
@@ -219,20 +219,9 @@ ParallelExecutor::ParallelExecutor(
     const std::unordered_set<std::string> &params,
     const ProgramDesc &startup_program, const ProgramDesc &main_program,
     const std::string &loss_var_name, Scope *scope)
-    : member_(new ParallelExecutorPrivate(num_threads)) {
-  member_->places_ = places;
+    : member_(new ParallelExecutorPrivate(num_threads, places)) {
   member_->global_scope_ = scope;

-  if (platform::is_cpu_place(places[0])) {
-    member_->fetch_dev_ctxs_[places[0]] = const_cast<platform::DeviceContext *>(
-        platform::DeviceContextPool::Instance().Get(places[0]));
-  } else {
-    for (auto &p : member_->places_) {
-      member_->fetch_dev_ctxs_[p] =
-          new platform::CUDADeviceContext(boost::get<platform::CUDAPlace>(p));
-    }
-  }
-
   // Step 1. RunStartupProgram and Bcast the params to devs.
   Executor exe(places[0]);
   exe.Run(startup_program, scope, 0);
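With the map-filling block gone, context creation is delegated to the pool's constructor, which handles the CPU/GPU split internally (see the device_context.cc hunk below). A hedged usage sketch built only from the API visible in this diff; the place values are illustrative:

#include "paddle/fluid/platform/device_context.h"

// One pool per executor, built from the place list in one step.
std::vector<platform::Place> places = {platform::CUDAPlace(0),
                                       platform::CUDAPlace(1)};
platform::DeviceContextPool fetch_pool(places);            // one ctx per place
platform::DeviceContext *ctx = fetch_pool.Get(places[0]);  // borrowed pointer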
@@ -509,7 +498,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
     // FIXME: Use new device context
     for (auto &p : member_->places_) {
-      op->dev_ctx_[p] = member_->fetch_dev_ctxs_[p];
+      op->dev_ctx_[p] = member_->fetch_dev_ctxs_.Get(p);
     }

     for (auto *var : vars) {
...
paddle/fluid/platform/device_context.cc
@@ -10,43 +10,45 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/platform/device_context.h"

+#include <unordered_set>
+
 #include "paddle/fluid/memory/memory.h"

 namespace paddle {
 namespace platform {

 DeviceContextPool* DeviceContextPool::pool = nullptr;

-const platform::DeviceContext* DeviceContextPool::Get(
-    const platform::Place& place) {
+platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
   auto it = device_contexts_.find(place);
   if (it == device_contexts_.end()) {
     PADDLE_THROW(
         "'Place' is not supported, Please re-compile with WITH_GPU "
         "option");
   }
-  return it->second;
+  return it->second.get();
 }
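Two details worth noting here: `Get` now returns a mutable pointer, which is what lets the parallel_executor.cc caller drop its `const_cast`, and `it->second.get()` hands out a borrowed pointer from the `unique_ptr` that now owns each context. A minimal sketch of that borrow pattern using only standard types:

#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct Ctx { int id; };

// Owning storage; lookups hand out non-owning pointers.
using CtxMap = std::unordered_map<std::string, std::unique_ptr<Ctx>>;

Ctx* Get(const CtxMap& m, const std::string& key) {
  auto it = m.find(key);
  if (it == m.end()) throw std::out_of_range("unknown place: " + key);
  return it->second.get();  // caller must not delete it or outlive the map
}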
 DeviceContextPool::DeviceContextPool(
     const std::vector<platform::Place>& places) {
   PADDLE_ENFORCE_GT(places.size(), 0);
-  for (size_t i = 0; i < places.size(); i++) {
-    if (platform::is_cpu_place(places[i])) {
+  using PtrType = std::unique_ptr<DeviceContext>;
+  std::unordered_set<Place, PlaceHash> set;
+  for (auto& p : places) {
+    set.insert(p);
+  }
+  for (auto& p : set) {
+    if (platform::is_cpu_place(p)) {
 #ifdef PADDLE_WITH_MKLDNN
-      device_contexts_.emplace(places[i],
-                               new platform::MKLDNNDeviceContext(
-                                   boost::get<platform::CPUPlace>(places[i])));
+      device_contexts_.emplace(
+          p, PtrType(new MKLDNNDeviceContext(boost::get<CPUPlace>(p))));
 #else
-      device_contexts_.emplace(places[i],
-                               new platform::CPUDeviceContext(
-                                   boost::get<platform::CPUPlace>(places[i])));
+      device_contexts_.emplace(
+          p, PtrType(new CPUDeviceContext(boost::get<CPUPlace>(p))));
 #endif
-    } else if (platform::is_gpu_place(places[i])) {
+    } else if (platform::is_gpu_place(p)) {
 #ifdef PADDLE_WITH_CUDA
-      device_contexts_.emplace(places[i],
-                               new platform::CUDADeviceContext(
-                                   boost::get<platform::CUDAPlace>(places[i])));
+      device_contexts_.emplace(
+          p, PtrType(new CUDADeviceContext(boost::get<CUDAPlace>(p))));
 #else
       PADDLE_THROW(
           "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
...
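The new loop first collapses the place list into an `unordered_set`, so a place passed twice gets only one context. This matters with the old raw-pointer map: on a duplicate key, `emplace` discards the insertion but the freshly `new`-ed context argument was still allocated, leaking it. A small standalone illustration of the dedup step, with ints standing in for places:

#include <cassert>
#include <unordered_set>
#include <vector>

int main() {
  // Stand-in for places: device ids. Duplicates collapse before construction.
  std::vector<int> places = {0, 1, 0, 1, 1};
  std::unordered_set<int> unique(places.begin(), places.end());
  assert(unique.size() == 2);  // one context per distinct place
  return 0;
}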
paddle/fluid/platform/device_context.h
@@ -160,7 +160,7 @@ class DeviceContextPool {
   }

   /*! \brief Return handle of single device context. */
-  const platform::DeviceContext* Get(const platform::Place& place);
+  platform::DeviceContext* Get(const platform::Place& place);

   template <typename Place>
   const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
@@ -173,19 +173,8 @@ class DeviceContextPool {
  private:
   static DeviceContextPool* pool;
-  constexpr static int LEFT_SHIFT = 8;
-  struct Hash {
-    std::hash<int> hash_;
-    size_t operator()(const platform::Place& place) const {
-      int pre_hash = place.which() << LEFT_SHIFT;
-      if (platform::is_gpu_place(place)) {
-        pre_hash += boost::get<platform::CUDAPlace>(place).GetDeviceId();
-      }
-      return hash_(pre_hash);
-    }
-  };
-  std::unordered_map<const platform::Place, const platform::DeviceContext*,
-                     Hash>
+  std::unordered_map<const platform::Place,
+                     std::unique_ptr<platform::DeviceContext>, PlaceHash>
       device_contexts_;
   DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
 };
...
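Beyond switching the values to `unique_ptr` (so the pool's implicitly generated destructor now releases every context), this hunk deletes the pool's private `Hash` functor in favor of the shared `PlaceHash` from place.h, so both this map and the dedup set above hash places identically. For reference, a sketch of the custom-hasher requirement the map relies on, with a plain struct standing in for the real `boost::variant`-based `Place`:

#include <cstddef>
#include <functional>
#include <memory>
#include <unordered_map>

struct FakePlace { int which; int dev_id; };
struct FakeCtx {};

struct FakePlaceHash {
  std::size_t operator()(const FakePlace& p) const {
    return std::hash<int>{}(p.dev_id << 4 | p.which);
  }
};

// Unordered containers also need key equality.
inline bool operator==(const FakePlace& a, const FakePlace& b) {
  return a.which == b.which && a.dev_id == b.dev_id;
}

std::unordered_map<FakePlace, std::unique_ptr<FakeCtx>, FakePlaceHash> ctxs;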
paddle/fluid/platform/place.h
@@ -67,12 +67,13 @@ bool is_same_place(const Place &, const Place &);
 struct PlaceHash {
   std::size_t operator()(const Place &p) const {
+    constexpr size_t num_dev_bits = 4;
     std::hash<int> ihash;
     size_t dev_id = 0;
     if (is_gpu_place(p)) {
       dev_id = boost::get<CUDAPlace>(p).device;
     }
-    return ihash(dev_id << 2 | p.which());
+    return ihash(dev_id << num_dev_bits | p.which());
   }
 };
...
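For intuition, a worked check of the packing this change adjusts: the variant index `p.which()` occupies the low four bits and the device id everything above, so up to 16 place types fit before colliding with device bits (the old shift of 2 allowed only 4). The `which` value below is illustrative, not Paddle's actual variant order:

#include <cassert>
#include <functional>

int main() {
  constexpr size_t num_dev_bits = 4;
  size_t dev_id = 3;  // e.g. CUDAPlace(3)
  size_t which = 1;   // hypothetical variant index
  size_t packed = dev_id << num_dev_bits | which;
  assert(packed == 0x31);  // 0b0011'0001: dev_id in high bits, which() in low
  (void)std::hash<int>{}(static_cast<int>(packed));  // final hashing step
  return 0;
}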