提交 15e8c80e 编写于 作者: Y Yu Yang 提交者: dzhwinter

Rename API of DeviceContext (#7055)

* Rename API of DeviceContext

Make them as usual names.

* Rename API of DeviceContext

Make them as usual names.

* Fix compile

* Fix compile

* Fix compile

* Fix compile

* Fix compile
上级 c31cbae5
...@@ -71,7 +71,7 @@ bool InitDevices(const std::vector<std::string> &devices) { ...@@ -71,7 +71,7 @@ bool InitDevices(const std::vector<std::string> &devices) {
places.emplace_back(platform::CPUPlace()); places.emplace_back(platform::CPUPlace());
LOG(WARNING) << "Not specified CPU device, create CPU by Default."; LOG(WARNING) << "Not specified CPU device, create CPU by Default.";
} }
platform::DeviceContextPool::Create(places); platform::DeviceContextPool::Init(places);
return true; return true;
} }
......
...@@ -388,8 +388,8 @@ void OperatorWithKernel::Run(const Scope& scope, ...@@ -388,8 +388,8 @@ void OperatorWithKernel::Run(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope); RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto dev_ctx = pool.Borrow(place); auto dev_ctx = pool.Get(place);
// check if op[type] has kernel registered. // check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels(); auto& all_op_kernels = AllOpKernels();
......
...@@ -29,7 +29,7 @@ bool MKLDNNLRNLayer::init(const LayerMap& layerMap, ...@@ -29,7 +29,7 @@ bool MKLDNNLRNLayer::init(const LayerMap& layerMap,
} }
/* the size of inputs for norm-layer is 1 */ /* the size of inputs for norm-layer is 1 */
CHECK_EQ(config_.inputs_size(), 1UL); CHECK_EQ(config_.inputs_size(), 1);
const NormConfig& conf = config_.inputs(0).norm_conf(); const NormConfig& conf = config_.inputs(0).norm_conf();
localSize_ = conf.size(); localSize_ = conf.size();
alpha_ = conf.scale(); alpha_ = conf.scale();
......
...@@ -35,8 +35,8 @@ class ArrayOp : public framework::OperatorBase { ...@@ -35,8 +35,8 @@ class ArrayOp : public framework::OperatorBase {
PADDLE_ENFORCE_EQ(i_tensor.numel(), 1); PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
size_t offset; size_t offset;
if (platform::is_gpu_place(i_tensor.place())) { if (platform::is_gpu_place(i_tensor.place())) {
......
...@@ -106,8 +106,9 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { ...@@ -106,8 +106,9 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
} }
auto slice = out->Slice(out_offset, out_offset + len); auto slice = out->Slice(out_offset, out_offset + len);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool =
auto &dev_ctx = *pool.Borrow(place); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(x[x_idx].Slice(start_offset, end_offset), place, framework::CopyFrom(x[x_idx].Slice(start_offset, end_offset), place,
dev_ctx, &slice); dev_ctx, &slice);
......
...@@ -82,8 +82,8 @@ class AssignOp : public framework::OperatorBase { ...@@ -82,8 +82,8 @@ class AssignOp : public framework::OperatorBase {
out != nullptr, out != nullptr,
"The Output(Out) should not be null if the Input(X) is set."); "The Output(Out) should not be null if the Input(X) is set.");
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
framework::VisitVarType(*x, AssignFunctor(out, dev_ctx)); framework::VisitVarType(*x, AssignFunctor(out, dev_ctx));
} }
......
...@@ -57,8 +57,8 @@ class BeamSearchDecodeOp : public framework::OperatorBase { ...@@ -57,8 +57,8 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope, void Run(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Borrow(dev_place); auto& dev_ctx = *pool.Get(dev_place);
framework::ExecutionContext ctx(*this, scope, dev_ctx); framework::ExecutionContext ctx(*this, scope, dev_ctx);
......
...@@ -195,8 +195,8 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope, ...@@ -195,8 +195,8 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
void CondOp::Run(const Scope& scope, const platform::Place& place) const { void CondOp::Run(const Scope& scope, const platform::Place& place) const {
// get device context from pool // get device context from pool
platform::DeviceContextPool& pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Borrow(place); auto& dev_ctx = *pool.Get(place);
PrepareDataForSubnet(scope, dev_ctx); PrepareDataForSubnet(scope, dev_ctx);
std::vector<framework::Scope*>& sub_scopes = GetSubScopes(scope); std::vector<framework::Scope*>& sub_scopes = GetSubScopes(scope);
......
...@@ -49,8 +49,8 @@ class FeedOp : public framework::OperatorBase { ...@@ -49,8 +49,8 @@ class FeedOp : public framework::OperatorBase {
auto *out_item = out_var->GetMutable<framework::FeedFetchType>(); auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(feed_item, place, dev_ctx, out_item); framework::CopyFrom(feed_item, place, dev_ctx, out_item);
out_item->set_lod(feed_item.lod()); out_item->set_lod(feed_item.lod());
......
...@@ -52,8 +52,8 @@ class FetchOp : public framework::OperatorBase { ...@@ -52,8 +52,8 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate // FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs? // CPU outputs?
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item); CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item);
dev_ctx.Wait(); dev_ctx.Wait();
......
...@@ -49,8 +49,8 @@ class FillConstantOp : public framework::OperatorBase { ...@@ -49,8 +49,8 @@ class FillConstantOp : public framework::OperatorBase {
out.mutable_data(dev_place, framework::ToTypeIndex(data_type)); out.mutable_data(dev_place, framework::ToTypeIndex(data_type));
} }
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(dev_place); auto &dev_ctx = *pool.Get(dev_place);
math::set_constant(dev_ctx, &out, value); math::set_constant(dev_ctx, &out, value);
} }
}; };
......
...@@ -69,8 +69,9 @@ class FillOp : public framework::OperatorBase { ...@@ -69,8 +69,9 @@ class FillOp : public framework::OperatorBase {
if (!force_cpu && platform::is_gpu_place(place)) { if (!force_cpu && platform::is_gpu_place(place)) {
// Copy tensor to out // Copy tensor to out
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool =
auto &dev_ctx = *pool.Borrow(place); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(tensor, place, dev_ctx, &out); framework::CopyFrom(tensor, place, dev_ctx, &out);
} }
} }
......
...@@ -40,8 +40,8 @@ class LoadOp : public framework::OperatorBase { ...@@ -40,8 +40,8 @@ class LoadOp : public framework::OperatorBase {
auto *tensor = out_var->GetMutable<framework::LoDTensor>(); auto *tensor = out_var->GetMutable<framework::LoDTensor>();
framework::DeserializeFromStream(fin, tensor); framework::DeserializeFromStream(fin, tensor);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
// copy CPU to GPU // copy CPU to GPU
......
...@@ -88,8 +88,9 @@ class LoDTensorToArrayOp : public framework::OperatorBase { ...@@ -88,8 +88,9 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
auto slice = out[i].Slice(static_cast<int>(offset), auto slice = out[i].Slice(static_cast<int>(offset),
static_cast<int>(offset + len)); static_cast<int>(offset + len));
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool =
auto &dev_ctx = *pool.Borrow(place); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(x.Slice(static_cast<int>(each_range.begin), framework::CopyFrom(x.Slice(static_cast<int>(each_range.begin),
static_cast<int>(each_range.end)), static_cast<int>(each_range.end)),
......
...@@ -30,8 +30,8 @@ class MergeLoDTensorOp : public framework::OperatorBase { ...@@ -30,8 +30,8 @@ class MergeLoDTensorOp : public framework::OperatorBase {
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(dev_place); auto &dev_ctx = *pool.Get(dev_place);
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>(); auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>(); auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>();
......
...@@ -305,7 +305,7 @@ int main(int argc, char **argv) { ...@@ -305,7 +305,7 @@ int main(int argc, char **argv) {
} }
VLOG(0) << " DeviceCount " << count; VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Create(places); paddle::platform::DeviceContextPool::Init(places);
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
......
...@@ -272,8 +272,9 @@ class RecurrentOp : public RecurrentBase { ...@@ -272,8 +272,9 @@ class RecurrentOp : public RecurrentBase {
false /*create_local_scope*/); false /*create_local_scope*/);
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool =
auto &dev_ctx = *pool.Borrow(place); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// Copy inside::output -> outside::output // Copy inside::output -> outside::output
// outside::output[seq_offset: seq_offset + 1] = inside::output // outside::output[seq_offset: seq_offset + 1] = inside::output
...@@ -326,8 +327,8 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -326,8 +327,8 @@ class RecurrentGradOp : public RecurrentBase {
auto *program = block->Program(); auto *program = block->Program();
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
for (size_t step_id = 0; step_id < seq_len; ++step_id) { for (size_t step_id = 0; step_id < seq_len; ++step_id) {
size_t seq_offset = reverse ? step_id : seq_len - step_id - 1; size_t seq_offset = reverse ? step_id : seq_len - step_id - 1;
......
...@@ -131,8 +131,8 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { ...@@ -131,8 +131,8 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
auto x_sliced = x.Slice(x_offset, x_offset + len); auto x_sliced = x.Slice(x_offset, x_offset + len);
auto out_sliced = out->Slice(out_offset, out_offset + len); auto out_sliced = out->Slice(out_offset, out_offset + len);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(x_sliced, out_sliced.place(), dev_ctx, &out_sliced); framework::CopyFrom(x_sliced, out_sliced.place(), dev_ctx, &out_sliced);
out_offset += len; out_offset += len;
return out_offset; return out_offset;
......
...@@ -91,8 +91,8 @@ class SaveOp : public framework::OperatorBase { ...@@ -91,8 +91,8 @@ class SaveOp : public framework::OperatorBase {
auto &tensor = var->Get<framework::LoDTensor>(); auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
framework::SerializeToStream(fout, tensor, dev_ctx); framework::SerializeToStream(fout, tensor, dev_ctx);
} }
......
...@@ -106,8 +106,8 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { ...@@ -106,8 +106,8 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
dx_tensor.mutable_data(x_tensor.place(), x_tensor.type()); dx_tensor.mutable_data(x_tensor.place(), x_tensor.type());
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
if (dout_var == nullptr) { // dx_tensor fill zero if (dout_var == nullptr) { // dx_tensor fill zero
math::set_constant(dev_ctx, &dx_tensor, 0.0f); math::set_constant(dev_ctx, &dx_tensor, 0.0f);
......
...@@ -45,8 +45,8 @@ class SplitLoDTensorOp : public framework::OperatorBase { ...@@ -45,8 +45,8 @@ class SplitLoDTensorOp : public framework::OperatorBase {
auto &x_lod = x.lod(); auto &x_lod = x.lod();
auto &mask_dim = mask.dims(); auto &mask_dim = mask.dims();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(dev_place); auto &dev_ctx = *pool.Get(dev_place);
std::unique_ptr<framework::LoDTensor> cpu_mask{new framework::LoDTensor()}; std::unique_ptr<framework::LoDTensor> cpu_mask{new framework::LoDTensor()};
if (platform::is_cpu_place(mask.place())) { if (platform::is_cpu_place(mask.place())) {
......
...@@ -40,8 +40,9 @@ class WriteToArrayOp : public ArrayOp { ...@@ -40,8 +40,9 @@ class WriteToArrayOp : public ArrayOp {
if (x_tensor.memory_size() > 0) { if (x_tensor.memory_size() > 0) {
auto *out_tensor = &out->at(offset); auto *out_tensor = &out->at(offset);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool =
auto &dev_ctx = *pool.Borrow(place); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
CopyFrom(x_tensor, place, dev_ctx, out_tensor); CopyFrom(x_tensor, place, dev_ctx, out_tensor);
out_tensor->set_lod(x_tensor.lod()); out_tensor->set_lod(x_tensor.lod());
...@@ -132,8 +133,9 @@ class ReadFromArrayOp : public ArrayOp { ...@@ -132,8 +133,9 @@ class ReadFromArrayOp : public ArrayOp {
auto *out_tensor = out->GetMutable<framework::LoDTensor>(); auto *out_tensor = out->GetMutable<framework::LoDTensor>();
size_t offset = GetOffset(scope, place); size_t offset = GetOffset(scope, place);
if (offset < x_array.size()) { if (offset < x_array.size()) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool =
auto &dev_ctx = *pool.Borrow(place); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(x_array[offset], place, dev_ctx, out_tensor); framework::CopyFrom(x_array[offset], place, dev_ctx, out_tensor);
out_tensor->set_lod(x_array[offset].lod()); out_tensor->set_lod(x_array[offset].lod());
} else { } else {
......
...@@ -17,7 +17,7 @@ namespace platform { ...@@ -17,7 +17,7 @@ namespace platform {
DeviceContextPool* DeviceContextPool::pool = nullptr; DeviceContextPool* DeviceContextPool::pool = nullptr;
const platform::DeviceContext* DeviceContextPool::Borrow( const platform::DeviceContext* DeviceContextPool::Get(
const platform::Place& place) { const platform::Place& place) {
auto it = device_contexts_.find(place); auto it = device_contexts_.find(place);
if (it == device_contexts_.end()) { if (it == device_contexts_.end()) {
...@@ -28,24 +28,6 @@ const platform::DeviceContext* DeviceContextPool::Borrow( ...@@ -28,24 +28,6 @@ const platform::DeviceContext* DeviceContextPool::Borrow(
return it->second; return it->second;
} }
std::vector<const platform::DeviceContext*> DeviceContextPool::Borrow(
const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
PADDLE_ENFORCE_LE(places.size(), device_contexts_.size());
std::vector<const platform::DeviceContext*> borrowed_contexts;
for (auto& place : places) {
auto it = device_contexts_.find(place);
if (it != device_contexts_.end()) {
borrowed_contexts.emplace_back(it->second);
} else {
PADDLE_THROW(
"'Place' is not supported, Please re-compile with WITH_GPU "
"option");
}
}
return borrowed_contexts;
}
DeviceContextPool::DeviceContextPool( DeviceContextPool::DeviceContextPool(
const std::vector<platform::Place>& places) { const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0); PADDLE_ENFORCE_GT(places.size(), 0);
......
...@@ -109,13 +109,13 @@ class DeviceContextPool { ...@@ -109,13 +109,13 @@ class DeviceContextPool {
public: public:
explicit DeviceContextPool(const std::vector<platform::Place>& places); explicit DeviceContextPool(const std::vector<platform::Place>& places);
static DeviceContextPool& Get() { static DeviceContextPool& Instance() {
PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!"); PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
return *pool; return *pool;
} }
/*! \brief Create should only called by Init function */ /*! \brief Create should only called by Init function */
static DeviceContextPool& Create(const std::vector<platform::Place>& places) { static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
if (pool == nullptr) { if (pool == nullptr) {
pool = new DeviceContextPool(places); pool = new DeviceContextPool(places);
} }
...@@ -123,13 +123,7 @@ class DeviceContextPool { ...@@ -123,13 +123,7 @@ class DeviceContextPool {
} }
/*! \brief Return handle of single device context. */ /*! \brief Return handle of single device context. */
const platform::DeviceContext* Borrow(const platform::Place& place); const platform::DeviceContext* Get(const platform::Place& place);
/*! \brief Return handle of multi-device context. */
std::vector<const platform::DeviceContext*> Borrow(
const std::vector<platform::Place>& places);
~DeviceContextPool() {}
private: private:
static DeviceContextPool* pool; static DeviceContextPool* pool;
......
...@@ -71,35 +71,20 @@ TEST(Device, DeviceContextPool) { ...@@ -71,35 +71,20 @@ TEST(Device, DeviceContextPool) {
using paddle::platform::CPUPlace; using paddle::platform::CPUPlace;
using paddle::platform::CUDAPlace; using paddle::platform::CUDAPlace;
DeviceContextPool& pool = DeviceContextPool::Get(); DeviceContextPool& pool = DeviceContextPool::Instance();
auto cpu_dev_ctx1 = pool.Borrow(CPUPlace()); auto cpu_dev_ctx1 = pool.Get(CPUPlace());
auto cpu_dev_ctx2 = pool.Borrow(CPUPlace()); auto cpu_dev_ctx2 = pool.Get(CPUPlace());
EXPECT_TRUE(cpu_dev_ctx2 == cpu_dev_ctx1); ASSERT_EQ(cpu_dev_ctx2, cpu_dev_ctx1);
std::vector<Place> gpu_places; std::vector<Place> gpu_places;
int count = paddle::platform::GetCUDADeviceCount(); int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
gpu_places.emplace_back(CUDAPlace(i)); auto dev_ctx = pool.Get(CUDAPlace(i));
} ASSERT_NE(dev_ctx, nullptr);
auto dev_ctxs = pool.Borrow(gpu_places);
for (size_t i = 0; i < dev_ctxs.size(); ++i) {
auto* dev_ctx = static_cast<const CUDADeviceContext*>(dev_ctxs[i]);
// check same as CUDAPlace(i)
CUDAPlace place = boost::get<CUDAPlace>(dev_ctx->GetPlace());
EXPECT_EQ(place.GetDeviceId(), static_cast<int>(i));
} }
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {
int dev_count = paddle::platform::GetCUDADeviceCount();
if (dev_count <= 1) {
LOG(WARNING) << "Cannot test multi-gpu DeviceContextPool, because the CUDA "
"device count is "
<< dev_count;
return 0;
}
std::vector<paddle::platform::Place> places; std::vector<paddle::platform::Place> places;
places.emplace_back(paddle::platform::CPUPlace()); places.emplace_back(paddle::platform::CPUPlace());
...@@ -109,7 +94,7 @@ int main(int argc, char** argv) { ...@@ -109,7 +94,7 @@ int main(int argc, char** argv) {
} }
VLOG(0) << " DeviceCount " << count; VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Create(places); paddle::platform::DeviceContextPool::Init(places);
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS(); return RUN_ALL_TESTS();
......
...@@ -144,7 +144,7 @@ int main(int argc, char** argv) { ...@@ -144,7 +144,7 @@ int main(int argc, char** argv) {
} }
VLOG(0) << " DeviceCount " << count; VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Create(places); paddle::platform::DeviceContextPool::Init(places);
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS(); return RUN_ALL_TESTS();
......
...@@ -63,9 +63,10 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -63,9 +63,10 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
auto *dst_ptr = static_cast<void *>(dst_tensor.mutable_data<CUR_TYPE>( auto *dst_ptr = static_cast<void *>(dst_tensor.mutable_data<CUR_TYPE>(
tensor.dims(), platform::CPUPlace())); tensor.dims(), platform::CPUPlace()));
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto dev_ctx = static_cast<const platform::CUDADeviceContext *>( auto dev_ctx = static_cast<const platform::CUDADeviceContext *>(
pool.Borrow(tensor.place())); pool.Get(tensor.place()));
paddle::platform::GpuMemcpyAsync( paddle::platform::GpuMemcpyAsync(
dst_ptr, src_ptr, sizeof(CUR_TYPE) * tensor.numel(), dst_ptr, src_ptr, sizeof(CUR_TYPE) * tensor.numel(),
...@@ -137,9 +138,9 @@ void PyCUDATensorSetFromArray( ...@@ -137,9 +138,9 @@ void PyCUDATensorSetFromArray(
self.Resize(framework::make_ddim(dims)); self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(place); auto *dst = self.mutable_data<T>(place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto dev_ctx = auto dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Borrow(place)); static_cast<const platform::CUDADeviceContext *>(pool.Get(place));
paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(T) * array.size(), paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(T) * array.size(),
cudaMemcpyHostToDevice, dev_ctx->stream()); cudaMemcpyHostToDevice, dev_ctx->stream());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册