提交 15e8c80e 编写于 作者: Y Yu Yang 提交者: dzhwinter

Rename API of DeviceContext (#7055)

* Rename API of DeviceContext

Make them as usual names.

* Rename API of DeviceContext

Make them as usual names.

* Fix compile

* Fix compile

* Fix compile

* Fix compile

* Fix compile
上级 c31cbae5
......@@ -71,7 +71,7 @@ bool InitDevices(const std::vector<std::string> &devices) {
places.emplace_back(platform::CPUPlace());
LOG(WARNING) << "Not specified CPU device, create CPU by Default.";
}
platform::DeviceContextPool::Create(places);
platform::DeviceContextPool::Init(places);
return true;
}
......
......@@ -388,8 +388,8 @@ void OperatorWithKernel::Run(const Scope& scope,
const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Get();
auto dev_ctx = pool.Borrow(place);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto dev_ctx = pool.Get(place);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
......
......@@ -29,7 +29,7 @@ bool MKLDNNLRNLayer::init(const LayerMap& layerMap,
}
/* the size of inputs for norm-layer is 1 */
CHECK_EQ(config_.inputs_size(), 1UL);
CHECK_EQ(config_.inputs_size(), 1);
const NormConfig& conf = config_.inputs(0).norm_conf();
localSize_ = conf.size();
alpha_ = conf.scale();
......
......@@ -35,8 +35,8 @@ class ArrayOp : public framework::OperatorBase {
PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
size_t offset;
if (platform::is_gpu_place(i_tensor.place())) {
......
......@@ -106,8 +106,9 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
}
auto slice = out->Slice(out_offset, out_offset + len);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(x[x_idx].Slice(start_offset, end_offset), place,
dev_ctx, &slice);
......
......@@ -82,8 +82,8 @@ class AssignOp : public framework::OperatorBase {
out != nullptr,
"The Output(Out) should not be null if the Input(X) is set.");
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::VisitVarType(*x, AssignFunctor(out, dev_ctx));
}
......
......@@ -57,8 +57,8 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope,
const platform::Place& dev_place) const override {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Get();
auto& dev_ctx = *pool.Borrow(dev_place);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(dev_place);
framework::ExecutionContext ctx(*this, scope, dev_ctx);
......
......@@ -195,8 +195,8 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
void CondOp::Run(const Scope& scope, const platform::Place& place) const {
// get device context from pool
platform::DeviceContextPool& pool = platform::DeviceContextPool::Get();
auto& dev_ctx = *pool.Borrow(place);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(place);
PrepareDataForSubnet(scope, dev_ctx);
std::vector<framework::Scope*>& sub_scopes = GetSubScopes(scope);
......
......@@ -49,8 +49,8 @@ class FeedOp : public framework::OperatorBase {
auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(feed_item, place, dev_ctx, out_item);
out_item->set_lod(feed_item.lod());
......
......@@ -52,8 +52,8 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs?
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item);
dev_ctx.Wait();
......
......@@ -49,8 +49,8 @@ class FillConstantOp : public framework::OperatorBase {
out.mutable_data(dev_place, framework::ToTypeIndex(data_type));
}
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(dev_place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
math::set_constant(dev_ctx, &out, value);
}
};
......
......@@ -69,8 +69,9 @@ class FillOp : public framework::OperatorBase {
if (!force_cpu && platform::is_gpu_place(place)) {
// Copy tensor to out
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(tensor, place, dev_ctx, &out);
}
}
......
......@@ -40,8 +40,8 @@ class LoadOp : public framework::OperatorBase {
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
framework::DeserializeFromStream(fin, tensor);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
if (platform::is_gpu_place(place)) {
// copy CPU to GPU
......
......@@ -88,8 +88,9 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
auto slice = out[i].Slice(static_cast<int>(offset),
static_cast<int>(offset + len));
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(x.Slice(static_cast<int>(each_range.begin),
static_cast<int>(each_range.end)),
......
......@@ -30,8 +30,8 @@ class MergeLoDTensorOp : public framework::OperatorBase {
void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(dev_place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>();
......
......@@ -305,7 +305,7 @@ int main(int argc, char **argv) {
}
VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Create(places);
paddle::platform::DeviceContextPool::Init(places);
testing::InitGoogleTest(&argc, argv);
......
......@@ -272,8 +272,9 @@ class RecurrentOp : public RecurrentBase {
false /*create_local_scope*/);
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// Copy inside::output -> outside::output
// outside::output[seq_offset: seq_offset + 1] = inside::output
......@@ -326,8 +327,8 @@ class RecurrentGradOp : public RecurrentBase {
auto *program = block->Program();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
for (size_t step_id = 0; step_id < seq_len; ++step_id) {
size_t seq_offset = reverse ? step_id : seq_len - step_id - 1;
......
......@@ -131,8 +131,8 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
auto x_sliced = x.Slice(x_offset, x_offset + len);
auto out_sliced = out->Slice(out_offset, out_offset + len);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(x_sliced, out_sliced.place(), dev_ctx, &out_sliced);
out_offset += len;
return out_offset;
......
......@@ -91,8 +91,8 @@ class SaveOp : public framework::OperatorBase {
auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::SerializeToStream(fout, tensor, dev_ctx);
}
......
......@@ -106,8 +106,8 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
dx_tensor.mutable_data(x_tensor.place(), x_tensor.type());
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
if (dout_var == nullptr) { // dx_tensor fill zero
math::set_constant(dev_ctx, &dx_tensor, 0.0f);
......
......@@ -45,8 +45,8 @@ class SplitLoDTensorOp : public framework::OperatorBase {
auto &x_lod = x.lod();
auto &mask_dim = mask.dims();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(dev_place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
std::unique_ptr<framework::LoDTensor> cpu_mask{new framework::LoDTensor()};
if (platform::is_cpu_place(mask.place())) {
......
......@@ -40,8 +40,9 @@ class WriteToArrayOp : public ArrayOp {
if (x_tensor.memory_size() > 0) {
auto *out_tensor = &out->at(offset);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
CopyFrom(x_tensor, place, dev_ctx, out_tensor);
out_tensor->set_lod(x_tensor.lod());
......@@ -132,8 +133,9 @@ class ReadFromArrayOp : public ArrayOp {
auto *out_tensor = out->GetMutable<framework::LoDTensor>();
size_t offset = GetOffset(scope, place);
if (offset < x_array.size()) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(x_array[offset], place, dev_ctx, out_tensor);
out_tensor->set_lod(x_array[offset].lod());
} else {
......
......@@ -17,7 +17,7 @@ namespace platform {
DeviceContextPool* DeviceContextPool::pool = nullptr;
const platform::DeviceContext* DeviceContextPool::Borrow(
const platform::DeviceContext* DeviceContextPool::Get(
const platform::Place& place) {
auto it = device_contexts_.find(place);
if (it == device_contexts_.end()) {
......@@ -28,24 +28,6 @@ const platform::DeviceContext* DeviceContextPool::Borrow(
return it->second;
}
std::vector<const platform::DeviceContext*> DeviceContextPool::Borrow(
const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
PADDLE_ENFORCE_LE(places.size(), device_contexts_.size());
std::vector<const platform::DeviceContext*> borrowed_contexts;
for (auto& place : places) {
auto it = device_contexts_.find(place);
if (it != device_contexts_.end()) {
borrowed_contexts.emplace_back(it->second);
} else {
PADDLE_THROW(
"'Place' is not supported, Please re-compile with WITH_GPU "
"option");
}
}
return borrowed_contexts;
}
DeviceContextPool::DeviceContextPool(
const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
......
......@@ -109,13 +109,13 @@ class DeviceContextPool {
public:
explicit DeviceContextPool(const std::vector<platform::Place>& places);
static DeviceContextPool& Get() {
static DeviceContextPool& Instance() {
PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
return *pool;
}
/*! \brief Create should only called by Init function */
static DeviceContextPool& Create(const std::vector<platform::Place>& places) {
static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
if (pool == nullptr) {
pool = new DeviceContextPool(places);
}
......@@ -123,13 +123,7 @@ class DeviceContextPool {
}
/*! \brief Return handle of single device context. */
const platform::DeviceContext* Borrow(const platform::Place& place);
/*! \brief Return handle of multi-device context. */
std::vector<const platform::DeviceContext*> Borrow(
const std::vector<platform::Place>& places);
~DeviceContextPool() {}
const platform::DeviceContext* Get(const platform::Place& place);
private:
static DeviceContextPool* pool;
......
......@@ -71,35 +71,20 @@ TEST(Device, DeviceContextPool) {
using paddle::platform::CPUPlace;
using paddle::platform::CUDAPlace;
DeviceContextPool& pool = DeviceContextPool::Get();
auto cpu_dev_ctx1 = pool.Borrow(CPUPlace());
auto cpu_dev_ctx2 = pool.Borrow(CPUPlace());
EXPECT_TRUE(cpu_dev_ctx2 == cpu_dev_ctx1);
DeviceContextPool& pool = DeviceContextPool::Instance();
auto cpu_dev_ctx1 = pool.Get(CPUPlace());
auto cpu_dev_ctx2 = pool.Get(CPUPlace());
ASSERT_EQ(cpu_dev_ctx2, cpu_dev_ctx1);
std::vector<Place> gpu_places;
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
gpu_places.emplace_back(CUDAPlace(i));
}
auto dev_ctxs = pool.Borrow(gpu_places);
for (size_t i = 0; i < dev_ctxs.size(); ++i) {
auto* dev_ctx = static_cast<const CUDADeviceContext*>(dev_ctxs[i]);
// check same as CUDAPlace(i)
CUDAPlace place = boost::get<CUDAPlace>(dev_ctx->GetPlace());
EXPECT_EQ(place.GetDeviceId(), static_cast<int>(i));
auto dev_ctx = pool.Get(CUDAPlace(i));
ASSERT_NE(dev_ctx, nullptr);
}
}
int main(int argc, char** argv) {
int dev_count = paddle::platform::GetCUDADeviceCount();
if (dev_count <= 1) {
LOG(WARNING) << "Cannot test multi-gpu DeviceContextPool, because the CUDA "
"device count is "
<< dev_count;
return 0;
}
std::vector<paddle::platform::Place> places;
places.emplace_back(paddle::platform::CPUPlace());
......@@ -109,7 +94,7 @@ int main(int argc, char** argv) {
}
VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Create(places);
paddle::platform::DeviceContextPool::Init(places);
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
......
......@@ -144,7 +144,7 @@ int main(int argc, char** argv) {
}
VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Create(places);
paddle::platform::DeviceContextPool::Init(places);
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
......
......@@ -63,9 +63,10 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
auto *dst_ptr = static_cast<void *>(dst_tensor.mutable_data<CUR_TYPE>(
tensor.dims(), platform::CPUPlace()));
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto dev_ctx = static_cast<const platform::CUDADeviceContext *>(
pool.Borrow(tensor.place()));
pool.Get(tensor.place()));
paddle::platform::GpuMemcpyAsync(
dst_ptr, src_ptr, sizeof(CUR_TYPE) * tensor.numel(),
......@@ -137,9 +138,9 @@ void PyCUDATensorSetFromArray(
self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Borrow(place));
static_cast<const platform::CUDADeviceContext *>(pool.Get(place));
paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(T) * array.size(),
cudaMemcpyHostToDevice, dev_ctx->stream());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册