未验证 提交 d1a4c53e 编写于 作者: L Leo Chen 提交者: GitHub

[NPU] support default stream (#31510)

上级 fead5631
......@@ -64,6 +64,13 @@ aclFormat ConvertToNpuFormat(DataLayout layout) {
return iter->second;
}
aclrtStream GetCurrentNPUStream() {
int device_id = GetCurrentNPUDeviceId();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(platform::NPUPlace(device_id));
return dev_ctx->stream();
}
NpuOpRunner::NpuOpRunner(std::string op_type) : op_type_(op_type) {
attr_ = aclopCreateAttr();
}
......@@ -249,7 +256,7 @@ aclTensorDesc *NpuOpRunner::CreateTensorDesc(Tensor tensor) {
auto format = ConvertToNpuFormat(tensor.layout());
auto dims = framework::vectorize(tensor.dims());
VLOG(4) << "dtype:" << dtype << " "
VLOG(4) << "NPU dtype:" << dtype << " "
<< "rank:" << dims.size() << " dims:" << tensor.dims()
<< " format:" << format;
......@@ -264,7 +271,7 @@ aclTensorDesc *NpuOpRunner::CreateTensorDesc(Tensor tensor) {
aclDataBuffer *NpuOpRunner::CreateDataBuffer(Tensor tensor) {
void *ptr = tensor.data<void>();
VLOG(4) << "ptr: " << ptr << ", size: " << tensor.memory_size();
VLOG(4) << "NPU ptr: " << ptr << ", size: " << tensor.memory_size();
auto *buffer = aclCreateDataBuffer(ptr, tensor.memory_size());
PADDLE_ENFORCE_NOT_NULL(
buffer, platform::errors::External("Call aclCreateDataBuffer failed."));
......@@ -272,11 +279,17 @@ aclDataBuffer *NpuOpRunner::CreateDataBuffer(Tensor tensor) {
}
void NpuOpRunner::Run(aclrtStream stream) {
if (!stream) {
VLOG(4) << "Run with default current npu stream: " << stream;
stream = GetCurrentNPUStream();
}
VLOG(4) << "op_type: " << op_type_;
VLOG(4) << "input_desc.size: " << input_descs_.size();
VLOG(4) << "output_desc.size: " << output_descs_.size();
VLOG(4) << "stream: " << stream;
VLOG(4) << "attr: " << attr_;
VLOG(4) << "stream: " << stream;
aclError ret = aclopCompileAndExecute(
op_type_.c_str(), input_descs_.size(), input_descs_.data(),
input_buffers_.data(), output_descs_.size(), output_descs_.data(),
......
......@@ -69,7 +69,7 @@ class NpuOpRunner {
std::vector<aclDataBuffer *> &GetOutputBuffers();
void Run(aclrtStream stream);
void Run(aclrtStream stream == nullptrr);
private:
aclTensorDesc *CreateTensorDesc(Tensor tensor);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册