未验证 提交 b5a8a0d9 编写于 作者: F fwenguang 提交者: GitHub

[MLU] add mlu buffer reader (#40131)

上级 041c4bca
......@@ -70,9 +70,25 @@ BufferedReader::BufferedReader(
stream_ = platform::NpuStreamResourcePool::Instance().New(dev_idx);
}
#endif
#ifdef PADDLE_WITH_MLU
if (platform::is_mlu_place(place_)) {
int dev_idx = place_.device;
compute_stream_ =
((platform::MLUDeviceContext *)(platform::DeviceContextPool::Instance()
.Get(place_)))
->stream();
events_.resize(buffer_size);
for (auto &event : events_) {
event = platform::MluEventResourcePool::Instance().New(dev_idx);
}
stream_ = platform::MluStreamResourcePool::Instance().New(dev_idx);
}
#endif
cpu_buffer_.resize(buffer_size);
cuda_buffer_.resize(buffer_size);
npu_buffer_.resize(buffer_size);
mlu_buffer_.resize(buffer_size);
ReadTillBufferFullAsync();
}
......@@ -256,6 +272,56 @@ void BufferedReader::ReadAsync(size_t i) {
platform::NPUStreamSync(stream_.get());
}
#endif
#ifdef PADDLE_WITH_MLU
if (platform::is_mlu_place(place_)) {
TensorVec &mlu = mlu_buffer_[i];
if (mlu.empty()) {
mlu.resize(cpu.size());
} else {
PADDLE_ENFORCE_EQ(
mlu.size(), cpu.size(),
platform::errors::InvalidArgument(
"Input tensor number on MLU and CPU devices are not matched. "
"The number on MLU is %d, on CPU is %d",
mlu.size(), cpu.size()));
}
std::vector<void *> mlu_ptrs;
mlu_ptrs.reserve(cpu.size());
for (size_t i = 0; i < cpu.size(); ++i) {
mlu[i].Resize(cpu[i].dims());
mlu[i].set_layout(cpu[i].layout());
mlu_ptrs.emplace_back(mlu[i].mutable_data(place_, cpu[i].type()));
}
platform::SetMLUDeviceId(place_.device);
PADDLE_ENFORCE_MLU_SUCCESS(
cnPlaceNotifier(events_[i].get(), compute_stream_));
PADDLE_ENFORCE_MLU_SUCCESS(cnWaitNotifier(events_[i].get()));
platform::RecordEvent record_event("BufferedReader:MemoryCopy",
platform::TracerEventType::UserDefined,
1);
for (size_t i = 0; i < cpu.size(); ++i) {
auto cpu_place = cpu[i].place();
auto cpu_ptr = cpu[i].data();
auto mlu_ptr = mlu_ptrs[i];
auto size =
cpu[i].numel() * paddle::framework::DataTypeSize(cpu[i].dtype());
if ((platform::is_mlu_place(cpu_place))) {
memory::Copy(place_, mlu_ptr, cpu_place, cpu_ptr, size,
stream_.get());
} else {
memory::Copy(place_, mlu_ptr, cpu_place, cpu_ptr, size,
stream_.get());
platform::MLUStreamSync(stream_.get());
}
mlu[i].set_lod(cpu[i].lod());
}
platform::MLUStreamSync(stream_.get());
}
#endif
return i;
}));
}
......@@ -291,6 +357,8 @@ void BufferedReader::ReadNextImpl(std::vector<framework::LoDTensor> *out) {
*out = std::move(cuda_buffer_[i]);
} else if (platform::is_npu_place(place_)) {
*out = std::move(npu_buffer_[i]);
} else if (platform::is_mlu_place(place_)) {
*out = std::move(mlu_buffer_[i]);
} else {
*out = std::move(cpu_buffer_[i]);
}
......
......@@ -29,6 +29,11 @@
#include "paddle/fluid/platform/device/npu/npu_info.h"
#include "paddle/fluid/platform/device/npu/npu_resource_pool.h"
#endif
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/mlu_info.h"
#include "paddle/fluid/platform/device/mlu/mlu_resource_pool.h"
#endif
namespace paddle {
namespace operators {
namespace reader {
......@@ -70,6 +75,7 @@ class BufferedReader : public framework::DecoratedReader {
std::vector<TensorVec> cpu_buffer_;
std::vector<TensorVec> cuda_buffer_;
std::vector<TensorVec> npu_buffer_;
std::vector<TensorVec> mlu_buffer_;
size_t prev_pos_{-1UL};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
gpuStream_t compute_stream_;
......@@ -82,6 +88,12 @@ class BufferedReader : public framework::DecoratedReader {
std::shared_ptr<platform::NpuStreamObject> stream_;
std::vector<std::shared_ptr<platform::NpuEventObject>> events_;
#endif
#ifdef PADDLE_WITH_MLU
mluStream compute_stream_;
std::shared_ptr<platform::MluStreamObject> stream_;
std::vector<std::shared_ptr<platform::MluEventObject>> events_;
#endif
};
} // namespace reader
......
......@@ -80,8 +80,8 @@ void StreamCallbackManager<Stream>::AddCallback(
#endif
#if PADDLE_WITH_MLU
VLOG(3) << "MLULaunchCallback at stream: " << stream_;
LOG(ERROR) << "failed to call MLULaunchCallback, "
VLOG(3) << "MLULaunchCallback at stream: " << stream_
<< " Failed to call MLULaunchCallback, "
<< "because mlu not support StreamAddCallback yet. "
<< "function: " << func;
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册