提交 f11934cb 编写于 作者: T Tomasz Patejko

MKLDNN conv residual data: residual data is reorder when formats are incorrect

上级 62a0fe08
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/framework/data_layout_transform.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -108,6 +110,11 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -108,6 +110,11 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
"@data-weights_mem_p", pipeline); "@data-weights_mem_p", pipeline);
} }
std::shared_ptr<mkldnn::memory> AcquireResidualDataMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_residual_data_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromDataPrimitive( std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromDataPrimitive(
void* ptr) { void* ptr) {
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(
...@@ -386,7 +393,15 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -386,7 +393,15 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_weights_memory_p = handler.AcquireWeightsMemory( auto user_weights_memory_p = handler.AcquireWeightsMemory(
user_weights_md, to_void_cast<T>(filter_data)); user_weights_md, to_void_cast<T>(filter_data));
T* output_data = nullptr; // create reorder primitive if the input format is not the preferred one
auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test);
auto output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
auto dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
if (fuse_residual_conn) { if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData"); auto residual_param = ctx.Input<Tensor>("ResidualData");
...@@ -399,20 +414,21 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -399,20 +414,21 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"Output and elementwise parameter need to have the " "Output and elementwise parameter need to have the "
"same dimension sizes"); "same dimension sizes");
output->ShareDataWith(*residual_param); if (residual_param->format() != output->format()) {
output_data = output->mutable_data<T>(ctx.GetPlace()); auto residual_data_tz =
paddle::framework::vectorize2int(residual_param->dims());
auto residual_data_type =
paddle::framework::ToMKLDNNDataType(residual_param->type());
auto user_residual_md = platform::MKLDNNMemDesc(
residual_data_tz, residual_data_type, residual_param->format());
auto user_residual_memory_p = handler.AcquireResidualDataMemory(
user_residual_md, to_void_cast<T>(residual_param_data));
platform::Reorder(*user_residual_memory_p, *dst_memory_p);
} else { } else {
output_data = output->ShareDataWith(*residual_param);
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize()); }
} }
// create reorder primitive if the input format is not the preferred one
auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test);
auto dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
// create convolution op primitive // create convolution op primitive
std::shared_ptr<mkldnn::convolution_forward> conv_p; std::shared_ptr<mkldnn::convolution_forward> conv_p;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册