Unverified commit 4096ff94 authored by Adam Osewski and committed by GitHub

Small optimizations for conv2d kernel subroutines. (#29188)

- Make sure that oneDNN memory descriptors are created only once, at the first iteration (a standalone sketch of this caching pattern follows the commit metadata below).
Parent 5c61eeef
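All three hunks in this commit apply the same idea: the handler caches oneDNN memory objects under string keys, so the memory descriptor is built only once, at the first iteration, and later iterations merely swap the data handle (and, where needed, re-run a cached reorder). Below is a minimal, standalone sketch of that acquire-or-create pattern against the public oneDNN (dnnl) C++ API; HandlerSketch, the key strings, and the shapes are illustrative assumptions, not code from this PR.

// Hedged sketch only (not PaddlePaddle code): memory objects live in a
// key -> memory map inside the handler, so the descriptor is built once and
// later iterations only refresh the data handle.
#include <memory>
#include <string>
#include <unordered_map>

#include "dnnl.hpp"  // oneDNN C++ API

class HandlerSketch {  // hypothetical stand-in for ConvMKLDNNHandlerT
 public:
  explicit HandlerSketch(dnnl::engine eng) : engine_(std::move(eng)) {}

  // Return the memory cached under `key` if present; otherwise create it.
  // On a cache hit only the pointer is updated, so the (comparatively
  // expensive) descriptor/memory construction happens only at iteration 0.
  std::shared_ptr<dnnl::memory> AcquireMemory(const std::string& key,
                                              const dnnl::memory::desc& md,
                                              void* data) {
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      it->second->set_data_handle(data);  // reuse object, point at new batch
      return it->second;
    }
    auto mem = std::make_shared<dnnl::memory>(md, engine_, data);
    cache_.emplace(key, mem);
    return mem;
  }

 private:
  dnnl::engine engine_;
  std::unordered_map<std::string, std::shared_ptr<dnnl::memory>> cache_;
};

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  HandlerSketch handler(eng);

  float batch0[8] = {0.f}, batch1[8] = {0.f};
  dnnl::memory::desc md({1, 2, 2, 2}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::nchw);

  auto m0 = handler.AcquireMemory("@src_mem_p_user", md, batch0);  // created
  auto m1 = handler.AcquireMemory("@src_mem_p_user", md, batch1);  // cache hit
  return m0 == m1 ? 0 : 1;  // same object; descriptor was built only once
}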
@@ -290,13 +290,25 @@ class ConvMKLDNNHandlerT
   std::shared_ptr<mkldnn::memory> AcquireSrcMemoryWithReorder(
       const framework::Tensor* input) {
     const T* input_data = input->data<T>();
-    auto user_src_md = platform::MKLDNNMemDesc(
-        framework::vectorize(input->dims()), platform::MKLDNNGetDataType<T>(),
-        input->format());
+    const std::string user_key_suffix{"@src_mem_p_user"};
+    auto user_src_mem_p = this->AcquireMemory(user_key_suffix);
 
-    return this->AcquireMemoryWithReorder(
-        user_src_md, this->fwd_pd_->src_desc(), to_void_cast<T>(input_data),
-        "@src_mem_p");
+    if (!user_src_mem_p) {
+      auto user_src_md = platform::MKLDNNMemDesc(
+          framework::vectorize(input->dims()), platform::MKLDNNGetDataType<T>(),
+          input->format());
+      return this->AcquireMemoryWithReorder(
+          user_src_md, this->fwd_pd_->src_desc(), to_void_cast<T>(input_data),
+          "@src_mem_p");
+    } else {
+      const std::string target_key_suffix{"@src_mem_p_target"};
+      const auto target_src_mem_p = this->AcquireMemory(target_key_suffix);
+      user_src_mem_p->set_data_handle(to_void_cast<T>(input_data));
+      if (user_src_mem_p != target_src_mem_p) {
+        this->AcquireReorder(user_src_mem_p, target_src_mem_p, "@src_mem_p");
+      }
+      return target_src_mem_p;
+    }
   }
 
   std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder(
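In the src path above, a cache hit can leave two distinct memory objects alive when the user layout differs from the layout the primitive expects: the user memory ("@src_mem_p_user") gets the new data handle, and the previously created reorder into the target memory ("@src_mem_p_target") is executed again. The following is a rough standalone illustration of that reuse with the plain oneDNN API; the layouts, shapes, and variable names are assumptions made for the sketch only.

// Hedged illustration only (not PaddlePaddle code): reuse a reorder primitive
// across iterations when only the source buffer changes.
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  // User-provided layout vs. the layout the primitive wants (assumed here).
  dnnl::memory::desc user_md({1, 3, 4, 4}, dnnl::memory::data_type::f32,
                             dnnl::memory::format_tag::nhwc);
  dnnl::memory::desc target_md({1, 3, 4, 4}, dnnl::memory::data_type::f32,
                               dnnl::memory::format_tag::nchw);

  float batch0[48] = {0.f};
  float batch1[48] = {0.f};

  dnnl::memory user_mem(user_md, eng, batch0);  // "@src_mem_p_user" analogue
  dnnl::memory target_mem(target_md, eng);      // "@src_mem_p_target" analogue
  dnnl::reorder user_to_target(user_mem, target_mem);  // built once

  // Iteration 0: reorder the first batch into the target layout.
  user_to_target.execute(strm, user_mem, target_mem);

  // Later iteration: same descriptors, new input buffer. Only the handle
  // moves; the cached reorder primitive runs again, nothing is re-created.
  user_mem.set_data_handle(batch1);
  user_to_target.execute(strm, user_mem, target_mem);
  strm.wait();
  return 0;
}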
@@ -324,14 +336,19 @@ class ConvMKLDNNHandlerT
 
   std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
       const framework::Tensor* bias, const bool is_test) {
-    const K* bias_data = bias->data<K>();
-    auto user_bias_md = platform::MKLDNNMemDesc(
-        framework::vectorize(bias->dims()), platform::MKLDNNGetDataType<K>(),
-        MKLDNNMemoryFormat::x);
+    auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target");
+    if (is_test && bias_mem_p) {
+      return bias_mem_p;
+    } else {
+      const K* bias_data = bias->data<K>();
+      auto user_bias_md = platform::MKLDNNMemDesc(
+          framework::vectorize(bias->dims()), platform::MKLDNNGetDataType<K>(),
+          MKLDNNMemoryFormat::x);
 
-    return this->AcquireMemoryWithReorder(
-        user_bias_md, this->fwd_pd_->bias_desc(), to_void_cast<K>(bias_data),
-        "@bias_mem_p", is_test);
+      return this->AcquireMemoryWithReorder(
+          user_bias_md, this->fwd_pd_->bias_desc(), to_void_cast<K>(bias_data),
+          "@bias_mem_p", is_test);
+    }
   }
 
   std::shared_ptr<mkldnn::memory> AcquireResidualMemory(
@@ -340,13 +357,19 @@ class ConvMKLDNNHandlerT
         residual_param->type() == framework::DataTypeTrait<T_out>::DataType()
             ? to_void_cast<T_out>(residual_param->data<T_out>())
             : to_void_cast<T>(residual_param->data<T>());
-    auto user_residual_md = platform::MKLDNNMemDesc(
-        framework::vectorize(residual_param->dims()),
-        framework::ToMKLDNNDataType(residual_param->type()),
-        residual_param->format());
+    auto residual_mem_p = this->AcquireMemory("@user_residual_data_mem_p");
+    if (residual_mem_p) {
+      residual_mem_p->set_data_handle(residual_data);
+      return residual_mem_p;
+    } else {
+      auto user_residual_md = platform::MKLDNNMemDesc(
+          framework::vectorize(residual_param->dims()),
+          framework::ToMKLDNNDataType(residual_param->type()),
+          residual_param->format());
 
-    return this->AcquireMemoryFromPrimitive(user_residual_md, residual_data,
-                                            "@user_residual_data_mem_p");
+      return this->AcquireMemoryFromPrimitive(user_residual_md, residual_data,
+                                              "@user_residual_data_mem_p");
+    }
   }
 
   std::shared_ptr<mkldnn::memory> AcquireDstMemoryWithResidual(