未验证 提交 4096ff94 编写于 作者: A Adam Osewski 提交者: GitHub

Small optimizations for conv2d kernel subroutines. (#29188)

- Make sure that oneDNN memory descriptors are created only once, on the
first iteration, and reused on subsequent iterations.
上级 5c61eeef
@@ -290,13 +290,25 @@ class ConvMKLDNNHandlerT
   std::shared_ptr<mkldnn::memory> AcquireSrcMemoryWithReorder(
       const framework::Tensor* input) {
     const T* input_data = input->data<T>();
-    auto user_src_md = platform::MKLDNNMemDesc(
-        framework::vectorize(input->dims()), platform::MKLDNNGetDataType<T>(),
-        input->format());
+    const std::string user_key_suffix{"@src_mem_p_user"};
+    auto user_src_mem_p = this->AcquireMemory(user_key_suffix);
 
-    return this->AcquireMemoryWithReorder(
-        user_src_md, this->fwd_pd_->src_desc(), to_void_cast<T>(input_data),
-        "@src_mem_p");
+    if (!user_src_mem_p) {
+      auto user_src_md = platform::MKLDNNMemDesc(
+          framework::vectorize(input->dims()), platform::MKLDNNGetDataType<T>(),
+          input->format());
+      return this->AcquireMemoryWithReorder(
+          user_src_md, this->fwd_pd_->src_desc(), to_void_cast<T>(input_data),
+          "@src_mem_p");
+    } else {
+      const std::string target_key_suffix{"@src_mem_p_target"};
+      const auto target_src_mem_p = this->AcquireMemory(target_key_suffix);
+      user_src_mem_p->set_data_handle(to_void_cast<T>(input_data));
+      if (user_src_mem_p != target_src_mem_p) {
+        this->AcquireReorder(user_src_mem_p, target_src_mem_p, "@src_mem_p");
+      }
+      return target_src_mem_p;
+    }
   }
 
   std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder(
@@ -324,14 +336,19 @@ class ConvMKLDNNHandlerT
 
   std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
       const framework::Tensor* bias, const bool is_test) {
-    const K* bias_data = bias->data<K>();
-    auto user_bias_md = platform::MKLDNNMemDesc(
-        framework::vectorize(bias->dims()), platform::MKLDNNGetDataType<K>(),
-        MKLDNNMemoryFormat::x);
-
-    return this->AcquireMemoryWithReorder(
-        user_bias_md, this->fwd_pd_->bias_desc(), to_void_cast<K>(bias_data),
-        "@bias_mem_p", is_test);
+    auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target");
+    if (is_test && bias_mem_p) {
+      return bias_mem_p;
+    } else {
+      const K* bias_data = bias->data<K>();
+      auto user_bias_md = platform::MKLDNNMemDesc(
+          framework::vectorize(bias->dims()), platform::MKLDNNGetDataType<K>(),
+          MKLDNNMemoryFormat::x);
+
+      return this->AcquireMemoryWithReorder(
+          user_bias_md, this->fwd_pd_->bias_desc(), to_void_cast<K>(bias_data),
+          "@bias_mem_p", is_test);
+    }
   }
 
   std::shared_ptr<mkldnn::memory> AcquireResidualMemory(
@@ -340,13 +357,19 @@ class ConvMKLDNNHandlerT
         residual_param->type() == framework::DataTypeTrait<T_out>::DataType()
             ? to_void_cast<T_out>(residual_param->data<T_out>())
             : to_void_cast<T>(residual_param->data<T>());
-    auto user_residual_md = platform::MKLDNNMemDesc(
-        framework::vectorize(residual_param->dims()),
-        framework::ToMKLDNNDataType(residual_param->type()),
-        residual_param->format());
+    auto residual_mem_p = this->AcquireMemory("@user_residual_data_mem_p");
+    if (residual_mem_p) {
+      residual_mem_p->set_data_handle(residual_data);
+      return residual_mem_p;
+    } else {
+      auto user_residual_md = platform::MKLDNNMemDesc(
+          framework::vectorize(residual_param->dims()),
+          framework::ToMKLDNNDataType(residual_param->type()),
+          residual_param->format());
 
-    return this->AcquireMemoryFromPrimitive(user_residual_md, residual_data,
-                                            "@user_residual_data_mem_p");
+      return this->AcquireMemoryFromPrimitive(user_residual_md, residual_data,
+                                              "@user_residual_data_mem_p");
+    }
   }
 
   std::shared_ptr<mkldnn::memory> AcquireDstMemoryWithResidual(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册