提交 5c333e41 编写于 作者: Y Yu Yang

Add dtor for dev_ctx

上级 15f5f10e
......@@ -35,18 +35,18 @@ using details::VarHandleBase;
class ParallelExecutorPrivate {
public:
explicit ParallelExecutorPrivate(size_t num_threads)
: pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {}
explicit ParallelExecutorPrivate(size_t num_threads,
const std::vector<platform::Place> &places)
: places_(places),
fetch_dev_ctxs_(places),
pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {}
std::vector<platform::Place> places_;
platform::DeviceContextPool fetch_dev_ctxs_;
std::vector<Scope *> local_scopes_;
Scope *global_scope_;
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
std::unordered_map<platform::Place, platform::DeviceContext *,
platform::PlaceHash>
fetch_dev_ctxs_;
platform::Place main_place_;
......@@ -219,20 +219,9 @@ ParallelExecutor::ParallelExecutor(
const std::unordered_set<std::string> &params,
const ProgramDesc &startup_program, const ProgramDesc &main_program,
const std::string &loss_var_name, Scope *scope)
: member_(new ParallelExecutorPrivate(num_threads)) {
member_->places_ = places;
: member_(new ParallelExecutorPrivate(num_threads, places)) {
member_->global_scope_ = scope;
if (platform::is_cpu_place(places[0])) {
member_->fetch_dev_ctxs_[places[0]] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(places[0]));
} else {
for (auto &p : member_->places_) {
member_->fetch_dev_ctxs_[p] =
new platform::CUDADeviceContext(boost::get<platform::CUDAPlace>(p));
}
}
// Step 1. RunStartupProgram and Bcast the params to devs.
Executor exe(places[0]);
exe.Run(startup_program, scope, 0);
......@@ -509,7 +498,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
// FIXME: Use new device context
for (auto &p : member_->places_) {
op->dev_ctx_[p] = member_->fetch_dev_ctxs_[p];
op->dev_ctx_[p] = member_->fetch_dev_ctxs_.Get(p);
}
for (auto *var : vars) {
......
......@@ -10,43 +10,45 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
#include <unordered_set>
#include "paddle/fluid/memory/memory.h"
namespace paddle {
namespace platform {
DeviceContextPool* DeviceContextPool::pool = nullptr;
/// Look up the device context registered for `place`.
///
/// Returns a non-owning pointer to the pooled context (ownership stays with
/// the pool's unique_ptr map entry). Throws via PADDLE_THROW when no context
/// was registered for this place — typically a GPU place in a CPU-only build.
///
/// NOTE: the diff rendering had left both the pre-commit signature
/// (returning `const DeviceContext*` from a raw-pointer map) and the
/// post-commit one in the text; this is the unified post-commit version.
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  auto it = device_contexts_.find(place);
  if (it == device_contexts_.end()) {
    PADDLE_THROW(
        "'Place' is not supported, Please re-compile with WITH_GPU "
        "option");
  }
  // Map values are std::unique_ptr<DeviceContext>; hand out the raw pointer.
  return it->second.get();
}
DeviceContextPool::DeviceContextPool(
const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
for (size_t i = 0; i < places.size(); i++) {
if (platform::is_cpu_place(places[i])) {
using PtrType = std::unique_ptr<DeviceContext>;
std::unordered_set<Place, PlaceHash> set;
for (auto& p : places) {
set.insert(p);
}
for (auto& p : set) {
if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
device_contexts_.emplace(places[i],
new platform::MKLDNNDeviceContext(
boost::get<platform::CPUPlace>(places[i])));
device_contexts_.emplace(
p, PtrType(new MKLDNNDeviceContext(boost::get<CPUPlace>(p))));
#else
device_contexts_.emplace(places[i],
new platform::CPUDeviceContext(
boost::get<platform::CPUPlace>(places[i])));
device_contexts_.emplace(
p, PtrType(new CPUDeviceContext(boost::get<CPUPlace>(p))));
#endif
} else if (platform::is_gpu_place(places[i])) {
} else if (platform::is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
device_contexts_.emplace(places[i],
new platform::CUDADeviceContext(
boost::get<platform::CUDAPlace>(places[i])));
device_contexts_.emplace(
p, PtrType(new CUDADeviceContext(boost::get<CUDAPlace>(p))));
#else
PADDLE_THROW(
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
......
......@@ -160,7 +160,7 @@ class DeviceContextPool {
}
/*! \brief Return handle of single device context. */
const platform::DeviceContext* Get(const platform::Place& place);
platform::DeviceContext* Get(const platform::Place& place);
template <typename Place>
const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
......@@ -173,19 +173,8 @@ class DeviceContextPool {
private:
static DeviceContextPool* pool;
constexpr static int LEFT_SHIFT = 8;
struct Hash {
std::hash<int> hash_;
size_t operator()(const platform::Place& place) const {
int pre_hash = place.which() << LEFT_SHIFT;
if (platform::is_gpu_place(place)) {
pre_hash += boost::get<platform::CUDAPlace>(place).GetDeviceId();
}
return hash_(pre_hash);
}
};
std::unordered_map<const platform::Place, const platform::DeviceContext*,
Hash>
std::unordered_map<const platform::Place,
std::unique_ptr<platform::DeviceContext>, PlaceHash>
device_contexts_;
DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};
......
......@@ -67,12 +67,13 @@ bool is_same_place(const Place &, const Place &);
/// Hash functor for platform::Place so it can key unordered containers.
///
/// Layout of the hashed integer: the low bits hold the variant tag
/// (`p.which()` — distinguishes CPUPlace/CUDAPlace/...), and the device id
/// is shifted above them by `num_dev_bits` so distinct (device, place-type)
/// pairs map to distinct pre-hash values (assumes which() < 2^num_dev_bits).
struct PlaceHash {
  std::size_t operator()(const Place &p) const {
    constexpr size_t num_dev_bits = 4;
    std::hash<int> ihash;
    size_t dev_id = 0;
    if (is_gpu_place(p)) {
      dev_id = boost::get<CUDAPlace>(p).device;
    }
    // BUG FIX: the text contained two consecutive return statements — the
    // stale pre-commit `dev_id << 2` (which made the next line unreachable
    // and contradicted num_dev_bits = 4) followed by the intended one.
    // Keep only the post-commit shift by num_dev_bits.
    return ihash(dev_id << num_dev_bits | p.which());
  }
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册