Unverified commit 4096ff94, authored by Adam Osewski, committed by GitHub

Small optimizations for conv2d kernel subroutines. (#29188)

- Make sure that oneDNN memory descriptors are created only once, at the first iteration.
Parent commit: 5c61eeef
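The change applies one caching pattern to the source, bias, and residual memory helpers: look the oneDNN memory object up by a string key, build it (together with its memory descriptor) only on a cache miss at the first iteration, and on later iterations just re-point it at the current buffer. Below is a rough, self-contained sketch of that idea against the plain oneDNN (dnnl) C++ API; the CachedMemoryPool class and its Acquire method are illustrative assumptions, not Paddle's handler API.

// Hedged sketch (not Paddle code): cache a dnnl::memory per string key so its
// descriptor is built only on the first iteration; later iterations only swap
// the data handle to the current input buffer.
#include <memory>
#include <string>
#include <unordered_map>
#include "dnnl.hpp"

class CachedMemoryPool {  // hypothetical helper mirroring the Acquire* idea
 public:
  explicit CachedMemoryPool(const dnnl::engine& eng) : eng_(eng) {}

  // Returns the cached memory for `key`, creating it on the first call only.
  std::shared_ptr<dnnl::memory> Acquire(const std::string& key,
                                        const dnnl::memory::desc& md,
                                        void* data_handle) {
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      // First iteration: build the memory object (and its descriptor) once.
      auto mem = std::make_shared<dnnl::memory>(md, eng_, data_handle);
      cache_.emplace(key, mem);
      return mem;
    }
    // Later iterations: reuse the cached object, only update the pointer.
    it->second->set_data_handle(data_handle);
    return it->second;
  }

 private:
  dnnl::engine eng_;
  std::unordered_map<std::string, std::shared_ptr<dnnl::memory>> cache_;
};

In the handler itself, AcquireMemory() returns a null shared_ptr on a cache miss and the calling method decides how to create and, where needed, reorder the memory, as the diff below shows.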
@@ -290,13 +290,25 @@ class ConvMKLDNNHandlerT
   std::shared_ptr<mkldnn::memory> AcquireSrcMemoryWithReorder(
       const framework::Tensor* input) {
     const T* input_data = input->data<T>();
+    const std::string user_key_suffix{"@src_mem_p_user"};
+    auto user_src_mem_p = this->AcquireMemory(user_key_suffix);
+    if (!user_src_mem_p) {
       auto user_src_md = platform::MKLDNNMemDesc(
           framework::vectorize(input->dims()), platform::MKLDNNGetDataType<T>(),
           input->format());
       return this->AcquireMemoryWithReorder(
           user_src_md, this->fwd_pd_->src_desc(), to_void_cast<T>(input_data),
           "@src_mem_p");
+    } else {
+      const std::string target_key_suffix{"@src_mem_p_target"};
+      const auto target_src_mem_p = this->AcquireMemory(target_key_suffix);
+      user_src_mem_p->set_data_handle(to_void_cast<T>(input_data));
+      if (user_src_mem_p != target_src_mem_p) {
+        this->AcquireReorder(user_src_mem_p, target_src_mem_p, "@src_mem_p");
+      }
+      return target_src_mem_p;
+    }
   }

   std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder(
@@ -324,6 +336,10 @@ class ConvMKLDNNHandlerT
   std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
       const framework::Tensor* bias, const bool is_test) {
+    auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target");
+    if (is_test && bias_mem_p) {
+      return bias_mem_p;
+    } else {
       const K* bias_data = bias->data<K>();
       auto user_bias_md = platform::MKLDNNMemDesc(
           framework::vectorize(bias->dims()), platform::MKLDNNGetDataType<K>(),
@@ -333,6 +349,7 @@ class ConvMKLDNNHandlerT
           user_bias_md, this->fwd_pd_->bias_desc(), to_void_cast<K>(bias_data),
           "@bias_mem_p", is_test);
     }
+  }

   std::shared_ptr<mkldnn::memory> AcquireResidualMemory(
       const framework::Tensor* residual_param) {
@@ -340,6 +357,11 @@ class ConvMKLDNNHandlerT
         residual_param->type() == framework::DataTypeTrait<T_out>::DataType()
             ? to_void_cast<T_out>(residual_param->data<T_out>())
             : to_void_cast<T>(residual_param->data<T>());
+    auto residual_mem_p = this->AcquireMemory("@user_residual_data_mem_p");
+    if (residual_mem_p) {
+      residual_mem_p->set_data_handle(residual_data);
+      return residual_mem_p;
+    } else {
       auto user_residual_md = platform::MKLDNNMemDesc(
           framework::vectorize(residual_param->dims()),
           framework::ToMKLDNNDataType(residual_param->type()),
@@ -348,6 +370,7 @@ class ConvMKLDNNHandlerT
       return this->AcquireMemoryFromPrimitive(user_residual_md, residual_data,
                                               "@user_residual_data_mem_p");
     }
+  }

   std::shared_ptr<mkldnn::memory> AcquireDstMemoryWithResidual(
       framework::Tensor* output, const framework::Tensor* residual_param) {
...
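The source path also reuses the user-to-target reorder: once the reorder between the framework layout and the layout preferred by the convolution primitive exists, later iterations only update the user memory's data handle and re-run the cached reorder. The following standalone illustration uses the plain oneDNN API; the shapes, format tags, and variable names are made up for the example and are not taken from the Paddle kernel.

// Hedged sketch: build the user->target reorder once, then on every iteration
// re-point the user memory at the new input buffer and re-execute the reorder.
#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  dnnl::memory::dims dims = {1, 8, 16, 16};  // example NCHW shape
  std::vector<float> input(1 * 8 * 16 * 16, 1.0f);

  // "User" memory reflects the framework layout (nchw); "target" memory uses a
  // blocked layout such as nChw8c that the convolution primitive may prefer.
  dnnl::memory::desc user_md(dims, dnnl::memory::data_type::f32,
                             dnnl::memory::format_tag::nchw);
  dnnl::memory::desc target_md(dims, dnnl::memory::data_type::f32,
                               dnnl::memory::format_tag::nChw8c);
  dnnl::memory user_mem(user_md, eng, input.data());
  dnnl::memory target_mem(target_md, eng);

  // Created once; the handler caches it under a key such as "@src_mem_p".
  dnnl::reorder user_to_target(user_mem, target_mem);

  for (int iter = 0; iter < 3; ++iter) {
    user_mem.set_data_handle(input.data());  // a new batch would go here
    user_to_target.execute(strm, user_mem, target_mem);
    strm.wait();
  }
  return 0;
}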