未验证 提交 bd0b38e6 编写于 作者: A Adam 提交者: GitHub

Refactor of conv fp32 oneDNN operator (#25137)

* Refactor of conv fp32 oneDNN operator
test=develop

* Formatting fix
test=develop

* Return Enforces
test=develop

* GetWeights improvements
test=develop
上级 b2f5a149
...@@ -210,6 +210,73 @@ class MKLDNNHandlerT { ...@@ -210,6 +210,73 @@ class MKLDNNHandlerT {
return mem_p; return mem_p;
} }
void AcquireReorder(const std::shared_ptr<mkldnn::memory>& user_memory_p,
const std::shared_ptr<mkldnn::memory>& target_memory_p,
const std::string& suffix) {
const auto key_reorder_p = key_ + suffix + "reorder_p";
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p == nullptr) {
reorder_p =
std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
}
mkldnn::stream astream(engine_);
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}});
astream.wait();
}
std::shared_ptr<mkldnn::memory> AcquireMemoryWithReorder(
const mkldnn::memory::desc& user_md,
const mkldnn::memory::desc& target_md, void* ptr,
const std::string& suffix, bool is_persistent = false) {
const auto target_key = key_ + suffix + "_target";
const auto key_reorder_p = key_ + suffix + "reorder_p";
const auto user_key = key_ + suffix + "_user";
auto target_memory_p =
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(target_key));
if (target_memory_p == nullptr) {
auto user_memory_p =
std::make_shared<dnnl::memory>(user_md, engine_, ptr);
if (user_md != target_md) {
target_memory_p = std::make_shared<mkldnn::memory>(target_md, engine_);
auto reorder_p =
std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
mkldnn::stream astream(engine_);
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}});
astream.wait();
} else {
target_memory_p = user_memory_p;
}
dev_ctx_.SetBlob(user_key, user_memory_p);
dev_ctx_.SetBlob(target_key, target_memory_p);
} else if (!is_persistent) {
mkldnn::stream astream(engine_);
auto user_memory_p =
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(user_key));
user_memory_p->set_data_handle(ptr);
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p != nullptr) {
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}});
astream.wait();
}
}
return target_memory_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory(const std::string& suffix) { std::shared_ptr<mkldnn::memory> AcquireMemory(const std::string& suffix) {
const auto local_key = key_ + suffix; const auto local_key = key_ + suffix;
return std::static_pointer_cast<mkldnn::memory>( return std::static_pointer_cast<mkldnn::memory>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册