Unverified commit d1a4c53e, authored by Leo Chen, committed by GitHub

[NPU] support default stream (#31510)

Parent: fead5631
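Add a GetCurrentNPUStream() helper that returns the aclrtStream owned by the NPUDeviceContext of the current device, and let NpuOpRunner::Run() fall back to it when the caller passes no stream: the stream parameter now defaults to nullptr.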
@@ -64,6 +64,13 @@ aclFormat ConvertToNpuFormat(DataLayout layout) {
   return iter->second;
 }
 
+aclrtStream GetCurrentNPUStream() {
+  int device_id = platform::GetCurrentNPUDeviceId();
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto *dev_ctx = static_cast<platform::NPUDeviceContext *>(pool.Get(platform::NPUPlace(device_id)));
+  return dev_ctx->stream();
+}
+
 NpuOpRunner::NpuOpRunner(std::string op_type) : op_type_(op_type) {
   attr_ = aclopCreateAttr();
 }
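For illustration, a minimal sketch of how the new helper can be used outside the runner (a hypothetical call site; aclrtSynchronizeStream is the standard ACL runtime call for blocking until a stream drains, and the error-check macro is an assumption, not shown in this diff):

    // Hypothetical call site: fetch the default stream of the current NPU
    // device and block until all work queued on it has completed.
    aclrtStream stream = GetCurrentNPUStream();
    // Assumption: PADDLE_ENFORCE_NPU_SUCCESS is available for ACL status checks.
    PADDLE_ENFORCE_NPU_SUCCESS(aclrtSynchronizeStream(stream));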
@@ -249,7 +256,7 @@ aclTensorDesc *NpuOpRunner::CreateTensorDesc(Tensor tensor) {
   auto format = ConvertToNpuFormat(tensor.layout());
   auto dims = framework::vectorize(tensor.dims());
 
-  VLOG(4) << "dtype:" << dtype << " "
+  VLOG(4) << "NPU dtype:" << dtype << " "
           << "rank:" << dims.size() << " dims:" << tensor.dims()
           << " format:" << format;
@@ -264,7 +271,7 @@ aclTensorDesc *NpuOpRunner::CreateTensorDesc(Tensor tensor) {
 aclDataBuffer *NpuOpRunner::CreateDataBuffer(Tensor tensor) {
   void *ptr = tensor.data<void>();
-  VLOG(4) << "ptr: " << ptr << ", size: " << tensor.memory_size();
+  VLOG(4) << "NPU ptr: " << ptr << ", size: " << tensor.memory_size();
   auto *buffer = aclCreateDataBuffer(ptr, tensor.memory_size());
   PADDLE_ENFORCE_NOT_NULL(
       buffer, platform::errors::External("Call aclCreateDataBuffer failed."));
@@ -272,11 +279,17 @@ aclDataBuffer *NpuOpRunner::CreateDataBuffer(Tensor tensor) {
 }
 
 void NpuOpRunner::Run(aclrtStream stream) {
+  if (!stream) {
+    stream = GetCurrentNPUStream();
+    VLOG(4) << "Run with default current npu stream: " << stream;
+  }
+
   VLOG(4) << "op_type: " << op_type_;
   VLOG(4) << "input_desc.size: " << input_descs_.size();
   VLOG(4) << "output_desc.size: " << output_descs_.size();
-  VLOG(4) << "stream: " << stream;
   VLOG(4) << "attr: " << attr_;
+  VLOG(4) << "stream: " << stream;
   aclError ret = aclopCompileAndExecute(
       op_type_.c_str(), input_descs_.size(), input_descs_.data(),
       input_buffers_.data(), output_descs_.size(), output_descs_.data(),
...
@@ -69,7 +69,7 @@ class NpuOpRunner {
   std::vector<aclDataBuffer *> &GetOutputBuffers();
 
-  void Run(aclrtStream stream);
+  void Run(aclrtStream stream = nullptr);
 
  private:
  aclTensorDesc *CreateTensorDesc(Tensor tensor);
...
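With the defaulted parameter in place, kernels can omit the stream argument entirely. A hedged usage sketch follows (AddInput/AddOutput are assumed builder methods of NpuOpRunner; only the constructor and GetOutputBuffers are visible in this diff, and "Relu" is an illustrative op type):

    // Hypothetical kernel-side usage: no stream passed, so Run() resolves
    // GetCurrentNPUStream() internally.
    Tensor x, out;
    NpuOpRunner runner("Relu");
    runner.AddInput(x);
    runner.AddOutput(out);
    runner.Run();  // equivalent to runner.Run(nullptr)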