diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 73530eac09e99c695ad8185d694ee9e7a4ed4396..fed6a7dfa5e1ce408d954ce3576bedc7e96b0d35 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -74,7 +74,9 @@ static mkldnn::memory::data_type GetDstType(bool is_int8, bool is_bfloat16, template class ConvMKLDNNHandlerT - : public platform::MKLDNNHandlerT { + : public platform::MKLDNNHandlerT { public: ConvMKLDNNHandlerT(const paddle::framework::ExecutionContext& ctx, const platform::MKLDNNDeviceContext& dev_ctx, @@ -82,11 +84,13 @@ class ConvMKLDNNHandlerT platform::Place cpu_place, const Tensor* input, const Tensor* filter, const Tensor* bias, Tensor* output, const std::string& unique_name) - : platform::MKLDNNHandlerT( + : platform::MKLDNNHandlerT( dev_ctx, mkldnn_engine, cpu_place, platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), unique_name)) { - if (!this->isCached()) { + if (!this->isCachedNonBlocking()) { PADDLE_ENFORCE_EQ( input->layout(), DataLayout::kMKLDNN, platform::errors::InvalidArgument( @@ -224,12 +228,12 @@ class ConvMKLDNNHandlerT auto bias_md = platform::MKLDNNMemDesc(bias_tz, data_type, MKLDNNMemoryFormat::x); - this->AcquireForwardPrimitiveDescriptor( + this->AcquireForwardPrimitiveDescriptorNonBlocking( conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct, src_md, weights_md, bias_md, dst_md, stride_dims, dilations_dims, mkldnn_paddings[0], mkldnn_paddings[1]); } else { - this->AcquireForwardPrimitiveDescriptor( + this->AcquireForwardPrimitiveDescriptorNonBlocking( conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct, src_md, weights_md, dst_md, stride_dims, dilations_dims, mkldnn_paddings[0], mkldnn_paddings[1]); @@ -237,6 +241,142 @@ class ConvMKLDNNHandlerT } } + ConvMKLDNNHandlerT(const framework::ExecutionContext& ctx, + const platform::MKLDNNDeviceContext& dev_ctx, + platform::Place cpu_place, const Tensor* in, + const Tensor* filter, const Tensor* bias, + const Tensor* out_grad, Tensor* filter_grad, + Tensor* in_x_grad, const std::string& unique_name) + : platform::MKLDNNHandlerT( + dev_ctx, dev_ctx.GetEngine(), cpu_place, + platform::CreateKey(dev_ctx, framework::vectorize(in->dims()), + unique_name)) { + if (!this->isBwdCached()) { + PADDLE_ENFORCE_EQ( + in->layout(), DataLayout::kMKLDNN, + platform::errors::InvalidArgument( + "The input tensor's layout should be %d, but got %d.", + DataLayout::kMKLDNN, in->layout())); + PADDLE_ENFORCE_NE(in->format(), MKLDNNMemoryFormat::undef, + platform::errors::InvalidArgument( + "Got wrong format for Input tensor.")); + + PADDLE_ENFORCE_EQ( + filter->layout(), DataLayout::kMKLDNN, + platform::errors::InvalidArgument( + "The filter tensor's layout should be %d, but got %d.", + DataLayout::kMKLDNN, filter->layout())); + PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef, + platform::errors::InvalidArgument( + "Got wrong format for Filter tensor.")); + + PADDLE_ENFORCE_EQ( + out_grad->layout(), DataLayout::kMKLDNN, + platform::errors::InvalidArgument( + "The output_grad tensor's layout should be %d, but got %d.", + DataLayout::kMKLDNN, out_grad->layout())); + PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef, + platform::errors::InvalidArgument( + "Wrong format set for output_grad tensor")); + + PADDLE_ENFORCE_EQ( + ctx.Attr("is_test"), false, + platform::errors::InvalidArgument( + "is_test attribute should be set to False in training phase.")); + + std::vector strides_temp = ctx.Attr>("strides"); + std::vector strides(begin(strides_temp), end(strides_temp)); + + std::vector paddings_temp = ctx.Attr>("paddings"); + std::vector paddings(begin(paddings_temp), end(paddings_temp)); + + std::vector dilations_temp = ctx.Attr>("dilations"); + std::vector dilations(begin(dilations_temp), + end(dilations_temp)); + + std::string padding_algorithm = + ctx.Attr("padding_algorithm"); + + int groups = ctx.Attr("groups"); + + auto input_dims = in->dims(); + auto data_dims = framework::slice_ddim(input_dims, 2, input_dims.size()); + auto filter_dims = filter->dims(); + auto filter_data_dims = + framework::slice_ddim(filter_dims, 2, filter_dims.size()); + + auto ksize = framework::vectorize(filter_data_dims); + + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + data_dims, strides, ksize); + + auto src_tz = framework::vectorize(in->dims()); + auto weights_tz = framework::vectorize(filter->dims()); + + int g = std::max(groups, 1); + platform::GetGroupConvWeightsTz(weights_tz, g); + auto dst_tz = paddle::framework::vectorize(out_grad->dims()); + + /* create memory descriptor for conv backward without specified format + * ('any') which lets a primitive (conv backward in this case) choose + * the memory format preferred for best performance + */ + const auto chosen_memory_format = MKLDNNMemoryFormat::any; + const auto weights_format = MKLDNNMemoryFormat::any; + + auto src_md = platform::MKLDNNMemDesc( + src_tz, platform::MKLDNNGetDataType(), chosen_memory_format); + const auto dst_md = platform::MKLDNNMemDesc( + dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); + auto diff_src_md = platform::MKLDNNMemDesc( + src_tz, platform::MKLDNNGetDataType(), chosen_memory_format); + auto weights_md = platform::MKLDNNMemDesc( + weights_tz, platform::MKLDNNGetDataType(), weights_format); + auto diff_weights_md = platform::MKLDNNMemDesc( + weights_tz, platform::MKLDNNGetDataType(), weights_format); + auto diff_dst_md = platform::MKLDNNMemDesc( + dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); + + auto mkldnn_paddings = platform::ToMkldnnPadding(paddings); + std::transform(dilations.begin(), dilations.end(), dilations.begin(), + [](int64_t i) { return i - 1; }); + const mkldnn::memory::dims dilations_dims = dilations; + + const mkldnn::memory::dims stride_dims = strides; + // Recreating FWD PD. For training there are no post ops in convolution + mkldnn::primitive_attr conv_attr; + if (bias) { + auto bias_tz = framework::vectorize(bias->dims()); + auto bias_md = platform::MKLDNNMemDesc( + bias_tz, mkldnn::memory::data_type::f32, MKLDNNMemoryFormat::x); + + this->AcquireForwardPrimitiveDescriptorNonBlocking( + conv_attr, mkldnn::prop_kind::forward_training, + dnnl::algorithm::convolution_direct, src_md, weights_md, bias_md, + dst_md, stride_dims, dilations_dims, mkldnn_paddings[0], + mkldnn_paddings[1]); + } else { + this->AcquireForwardPrimitiveDescriptorNonBlocking( + conv_attr, mkldnn::prop_kind::forward_training, + dnnl::algorithm::convolution_direct, src_md, weights_md, dst_md, + stride_dims, dilations_dims, mkldnn_paddings[0], + mkldnn_paddings[1]); + } + + this->AcquireBackwardPrimitiveDescriptorNonBlocking( + mkldnn::algorithm::convolution_direct, diff_src_md, weights_md, + diff_dst_md, strides, dilations_dims, mkldnn_paddings[0], + mkldnn_paddings[1]); + + this->AcquireBackwardWeightsPrimitiveDescriptorNonBlocking( + mkldnn::algorithm::convolution_direct, src_md, diff_weights_md, + diff_dst_md, strides, dilations_dims, mkldnn_paddings[0], + mkldnn_paddings[1]); + } + } + mkldnn::primitive_attr CreatePostOps( std::string fuse_activation, float fuse_alpha, float fuse_beta, bool fuse_residual_conn, const std::vector output_shift_scale = {}, @@ -280,27 +420,75 @@ class ConvMKLDNNHandlerT return conv_attr; } + std::shared_ptr + AcquireWeightsMemoryWithReorderFromDataPrimitive( + const framework::Tensor* filter, const int groups, const bool is_conv3d) { + const K* filter_data = filter->data(); + auto weights_tz = framework::vectorize(filter->dims()); + platform::GetGroupConvWeightsTz(weights_tz, groups); + + auto user_src_md = platform::MKLDNNMemDesc( + weights_tz, platform::MKLDNNGetDataType(), + GetWeightsFormat(filter->format(), groups, is_conv3d)); + + return this->AcquireMemoryWithReorder( + user_src_md, this->bwd_pd_->weights_desc(), + to_void_cast(filter_data), "@weights_mem_d_p", false); + } + std::shared_ptr AcquireSrcMemoryWithReorder( const framework::Tensor* input) { - const T* input_data = input->data(); - const std::string user_key_suffix{"@src_mem_p_user"}; - auto user_src_mem_p = this->AcquireMemory(user_key_suffix); + return this->AcquireMemoryWithReorderPrimitive( + input, "@src_mem_p_user", "@src_mem_p_target", "@src_mem_p", + this->fwd_pd_->src_desc()); + } - if (!user_src_mem_p) { - auto user_src_md = platform::MKLDNNMemDesc( - framework::vectorize(input->dims()), platform::MKLDNNGetDataType(), - input->format()); + std::shared_ptr + AcquireSrcMemoryWithReorderFromWeightsPrimitive( + const framework::Tensor* input) { + return this->AcquireMemoryWithReorderPrimitive( + input, "@src_mem_w_p_user", "@src_mem_w_p_target", "@src_mem_w_p", + this->bwd_w_pd_->src_desc()); + } + + std::shared_ptr + AcquireDiffDstMemoryWithReorderFromWeightsPrimitive( + const framework::Tensor* out_grad) { + return this->AcquireMemoryWithReorderPrimitive( + out_grad, "@diff_dst_mem_w_p_user", "@diff_dst_mem_w_p_target", + "@diff_dst_mem_w_p", this->bwd_w_pd_->diff_dst_desc()); + } + + std::shared_ptr + AcquireDiffDstMemoryWithReorderMemoryFromDataPrimitive( + const framework::Tensor* out_grad) { + return this->AcquireMemoryWithReorderPrimitive( + out_grad, "@diff_dst_mem_p_user", "@diff_dst_mem_p_target", + "@diff_dst_mem_p", this->bwd_pd_->diff_dst_desc()); + } + + std::shared_ptr AcquireMemoryWithReorderPrimitive( + const framework::Tensor* in_mem, const char* key_mem_user, + const char* key_mem_target, const char* key_mem, + const mkldnn::memory::desc& mem_md) { + const T* in_mem_data = in_mem->data(); + const std::string user_key_suffix{key_mem_user}; + auto user_mem_p = this->AcquireMemory(user_key_suffix); + + if (!user_mem_p) { + auto user_mem_md = platform::MKLDNNMemDesc( + framework::vectorize(in_mem->dims()), + platform::MKLDNNGetDataType(), in_mem->format()); return this->AcquireMemoryWithReorder( - user_src_md, this->fwd_pd_->src_desc(), to_void_cast(input_data), - "@src_mem_p"); + user_mem_md, mem_md, to_void_cast(in_mem_data), key_mem); } else { - const std::string target_key_suffix{"@src_mem_p_target"}; - const auto target_src_mem_p = this->AcquireMemory(target_key_suffix); - user_src_mem_p->set_data_handle(to_void_cast(input_data)); - if (user_src_mem_p != target_src_mem_p) { - this->AcquireReorder(user_src_mem_p, target_src_mem_p, "@src_mem_p"); + const std::string target_key_suffix{key_mem_target}; + const auto target_mem_p = this->AcquireMemory(target_key_suffix); + user_mem_p->set_data_handle(to_void_cast(in_mem_data)); + if (user_mem_p != target_mem_p) { + this->AcquireReorder(user_mem_p, target_mem_p, key_mem); } - return target_src_mem_p; + return target_mem_p; } } @@ -866,7 +1054,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { } }; -template +template class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { public: void Compute(const paddle::framework::ExecutionContext& ctx) const override { @@ -879,189 +1067,44 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { const Tensor* input = ctx.Input("Input"); const Tensor* filter = ctx.Input("Filter"); + const Tensor* bias = + ctx.HasInput("Bias") ? ctx.Input("Bias") : nullptr; const Tensor* output_grad = ctx.Input(framework::GradVarName("Output")); Tensor* input_grad = ctx.Output(framework::GradVarName("Input")); Tensor* filter_grad = ctx.Output(framework::GradVarName("Filter")); - PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN, - platform::errors::InvalidArgument( - "The input tensor's layout should be %d, but got %d.", - DataLayout::kMKLDNN, input->layout())); - PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef, - platform::errors::InvalidArgument( - "Got wrong format for Input tensor.")); - - PADDLE_ENFORCE_EQ( - filter->layout(), DataLayout::kMKLDNN, - platform::errors::InvalidArgument( - "The filter tensor's layout should be %d, but got %d.", - DataLayout::kMKLDNN, filter->layout())); - PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef, - platform::errors::InvalidArgument( - "Got wrong format for Filter tensor.")); - - PADDLE_ENFORCE_EQ( - output_grad->layout(), DataLayout::kMKLDNN, - platform::errors::InvalidArgument( - "The output_grad tensor's layout should be %d, but got %d.", - DataLayout::kMKLDNN, output_grad->layout())); - PADDLE_ENFORCE_NE(output_grad->format(), MKLDNNMemoryFormat::undef, - platform::errors::InvalidArgument( - "Wrong format set for output_grad tensor")); - - PADDLE_ENFORCE_EQ( - ctx.Attr("is_test"), false, - platform::errors::InvalidArgument( - "is_test attribute should be set to False in training phase.")); - if (!input_grad && !filter_grad) return; - std::vector strides_temp = ctx.Attr>("strides"); - std::vector strides(begin(strides_temp), end(strides_temp)); - - std::vector paddings_temp = ctx.Attr>("paddings"); - std::vector paddings(begin(paddings_temp), end(paddings_temp)); - - std::vector dilations_temp = ctx.Attr>("dilations"); - std::vector dilations(begin(dilations_temp), end(dilations_temp)); - - std::string padding_algorithm = ctx.Attr("padding_algorithm"); - - int groups = ctx.Attr("groups"); - - bool is_conv3d = strides.size() == 3U; - const T* input_data = input->data(); - const T* filter_data = filter->data(); - const T* output_grad_data = output_grad->data(); - T* input_grad_data = nullptr; - T* filter_grad_data = nullptr; - - auto input_dims = input->dims(); - auto data_dims = framework::slice_ddim(input_dims, 2, input_dims.size()); - auto filter_dims = filter->dims(); - auto filter_data_dims = - framework::slice_ddim(filter_dims, 2, filter_dims.size()); - - auto ksize = framework::vectorize(filter_data_dims); - - UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, - data_dims, strides, ksize); - - auto src_tz = paddle::framework::vectorize(input->dims()); - auto weights_tz = paddle::framework::vectorize(filter->dims()); - - int g = std::max(groups, 1); - platform::GetGroupConvWeightsTz(weights_tz, g); - auto dst_tz = paddle::framework::vectorize(output_grad->dims()); - - auto src_format = input->format(); - MKLDNNMemoryFormat weights_format = - GetWeightsFormat(filter->format(), g, is_conv3d); - - // Get an unique name from "argument" name of "input" and "Filter" variable - // as well as attributes of primitive to be created - // This name will be used as key when saving info into device context - std::string key = platform::CreateKey( - dev_ctx, src_tz, ctx.InputName("Input") + ctx.InputName("Filter")); - - const std::string key_conv_pd = key + "@fwd_pd"; - key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); - std::vector pipeline; - - // Create user memory descriptors - auto user_src_md = platform::MKLDNNMemDesc( - {src_tz}, platform::MKLDNNGetDataType(), src_format); - auto user_weights_md = platform::MKLDNNMemDesc( - {weights_tz}, platform::MKLDNNGetDataType(), weights_format); - auto user_diff_dst_md = platform::MKLDNNMemDesc( - {dst_tz}, platform::MKLDNNGetDataType(), output_grad->format()); - - /* create memory descriptor for conv backward without specified format - * ('any') which lets a primitive (conv backward in this case) choose - * the memory format preferred for best performance - */ - auto chosen_memory_format = MKLDNNMemoryFormat::any; - weights_format = MKLDNNMemoryFormat::any; - - auto src_md = platform::MKLDNNMemDesc( - src_tz, platform::MKLDNNGetDataType(), chosen_memory_format); - auto diff_src_md = platform::MKLDNNMemDesc( - src_tz, platform::MKLDNNGetDataType(), chosen_memory_format); - auto weights_md = platform::MKLDNNMemDesc( - weights_tz, platform::MKLDNNGetDataType(), weights_format); - auto diff_weights_md = platform::MKLDNNMemDesc( - weights_tz, platform::MKLDNNGetDataType(), weights_format); - auto diff_dst_md = platform::MKLDNNMemDesc( - dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); - // Retrieve conv_pd from device context - auto conv_pd = - std::static_pointer_cast( - dev_ctx.GetBlob(key_conv_pd)); - PADDLE_ENFORCE_NE(conv_pd, nullptr, - platform::errors::InvalidArgument( - "Fail to find conv_pd in device context")); - - auto mkldnn_paddings = platform::ToMkldnnPadding(paddings); - std::transform(dilations.begin(), dilations.end(), dilations.begin(), - [](int64_t i) { return i - 1; }); - const mkldnn::memory::dims dilations_dims = dilations; - // create backward convolution weights primitive descriptor - auto conv_bwd_weights_desc = mkldnn::convolution_backward_weights::desc( - mkldnn::algorithm::convolution_direct, src_md, diff_weights_md, - diff_dst_md, strides, dilations_dims, mkldnn_paddings[0], - mkldnn_paddings[1]); - - auto conv_bwd_weights_pd = - std::make_shared( - conv_bwd_weights_desc, mkldnn_engine, *conv_pd); - - // create backward convolution data primitive descriptor - auto conv_bwd_data_desc = mkldnn::convolution_backward_data::desc( - mkldnn::algorithm::convolution_direct, diff_src_md, weights_md, - diff_dst_md, strides, dilations_dims, mkldnn_paddings[0], - mkldnn_paddings[1]); - - auto conv_bwd_data_pd = - std::make_shared( - conv_bwd_data_desc, mkldnn_engine, *conv_pd); - - platform::ConvMKLDNNHandler handler(conv_pd, conv_bwd_data_pd, - conv_bwd_weights_pd, dev_ctx, - mkldnn_engine, key); + // TODO(jczaja): Are all tensors really needed? + ConvMKLDNNHandlerT handler( + ctx, dev_ctx, ctx.GetPlace(), input, filter, bias, output_grad, + filter_grad, input_grad, + ctx.InputName("Input") + ctx.InputName("Filter")); // create mkldnn memory from input tensors (data/weights) - auto user_src_memory_p = - handler.AcquireSrcMemory(user_src_md, to_void_cast(input_data)); - auto user_weights_memory_p = handler.AcquireWeightsMemory( - user_weights_md, to_void_cast(filter_data)); - auto user_diff_dst_memory_p = handler.AcquireDiffDstMemory( - user_diff_dst_md, to_void_cast(output_grad_data)); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - if (filter_grad) { - auto src_memory_p = handler.AcquireSrcMemoryFromWeightsPrimitive( - user_src_memory_p, pipeline); - - auto diff_dst_memory_4filter_p = - handler.AcquireDiffDstMemoryFromWeightsPrimitive( - user_diff_dst_memory_p, pipeline); - const size_t size = handler.GetDiffWeightsMemorySize(); - filter_grad_data = filter_grad->mutable_data(ctx.GetPlace(), size); + if (filter_grad) { + auto src_memory_p = + handler.AcquireSrcMemoryWithReorderFromWeightsPrimitive(input); + auto diff_dst_memory_p = + handler.AcquireDiffDstMemoryWithReorderFromWeightsPrimitive( + output_grad); // For convoluition with groups write filter grad into // oneDNN buffer and then we reorder it into filter_grad tensor + int g = std::max(ctx.Attr("groups"), 1); auto diff_weights_memory_p = - g > 1 ? handler.AcquireDiffWeightsMemoryFromWeightsPrimitive() - : handler.AcquireDiffWeightsMemoryFromWeightsPrimitive( - reinterpret_cast(filter_grad_data)); + g > 1 ? handler.AcquireDiffWeightsMemory() + : handler.AcquireDiffWeightsMemory(filter_grad); - auto conv_bwd_weights_p = handler.AcquireConvolutionBackwardWeights(); + auto conv_bwd_weights_p = handler.AcquireBackwardWeightsPrimitive(); // TODO(grygielski) why no bias_diff? conv_bwd_weights_p->execute( astream, {{MKLDNN_ARG_SRC, *src_memory_p}, - {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_4filter_p}, + {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p}, {MKLDNN_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}}); astream.wait(); @@ -1073,10 +1116,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { // For convolution with groups convert from blocked to NCHW // otherwise there will be problems in next operators working on this data if (g > 1) { - memory::data_type in_type = - framework::ToMKLDNNDataType(filter_grad->type()); + memory::data_type in_type = framework::ToMKLDNNDataType(filter->type()); // for 3d conv with groups (six dimensional data reorder to goidhw) // for 2d conv with groups (five dimensional data reorder to goihw) + // auto weights_tz = paddle::framework::vectorize(filter->dims()); + + auto weights_tz = diff_weights_memory_p->get_desc().dims(); mkldnn::memory::format_tag out_format = weights_tz.size() == 6 ? mkldnn::memory::format_tag::goidhw : mkldnn::memory::format_tag::goihw; @@ -1084,9 +1129,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { out_format, in_type); key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); - platform::ReorderMKLDNNHandler handler(weights_tz, filter_grad->type(), - in_type, dev_ctx, mkldnn_engine, - key); + platform::ReorderMKLDNNHandler handler( + weights_tz, filter->type(), in_type, dev_ctx, mkldnn_engine, key); auto reorder_dst_memory_p = handler.AcquireDstMemory(filter_grad, out_format, ctx.GetPlace()); @@ -1113,24 +1157,21 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { } } if (input_grad) { - auto weights_memory_p = handler.AcquireWeightsMemoryFromDataPrimitive( - user_weights_memory_p, pipeline); - - auto diff_dst_memory_4data_p = - handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p, - pipeline); - - const size_t size = handler.GetDiffSourceMemorySize(); - input_grad_data = input_grad->mutable_data(ctx.GetPlace(), size); + auto weights_memory_p = + handler.AcquireWeightsMemoryWithReorderFromDataPrimitive( + filter, ctx.Attr("groups"), + ctx.Attr>("strides").size() == 3U); - auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive( - reinterpret_cast(input_grad_data)); + auto diff_dst_memory_p = + handler.AcquireDiffDstMemoryWithReorderMemoryFromDataPrimitive( + output_grad); + auto diff_src_memory_p = handler.AcquireDiffSrcMemory(input_grad); - auto conv_bwd_data_p = handler.AcquireConvolutionBackwardData(); + auto conv_bwd_data_p = handler.AcquireBackwardPrimitive(); conv_bwd_data_p->execute(astream, {{MKLDNN_ARG_WEIGHTS, *weights_memory_p}, - {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_4data_p}, + {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p}, {MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}}); astream.wait(); @@ -1167,7 +1208,7 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, FP32, ops::kConvMKLDNNFP32, - ops::ConvMKLDNNGradOpKernel); + ops::ConvMKLDNNGradOpKernel); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, MKLDNN, ::paddle::platform::CPUPlace, FP32, @@ -1177,4 +1218,4 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, MKLDNN, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d_grad, MKLDNN, ::paddle::platform::CPUPlace, FP32, ops::kConvMKLDNNFP32, - ops::ConvMKLDNNGradOpKernel); + ops::ConvMKLDNNGradOpKernel); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index d6563be48fe484cae5c54c52c87e5c3a1493e584..2981e5502ce6ac2d5cf55e8bf60a30035f032a3a 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -35,7 +35,8 @@ using user_function = std::function(const float*)>; using memory = mkldnn::memory; template + typename TBackward = mkldnn_dummy_primitive, + typename TBackward_params = mkldnn_dummy_primitive> class MKLDNNHandlerT { public: MKLDNNHandlerT(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine, @@ -72,6 +73,21 @@ class MKLDNNHandlerT { return backward_p; } + std::shared_ptr AcquireBackwardWeightsPrimitive() { + const std::string key_p = key_ + "@bwd_w_p"; + auto backward_p = + std::static_pointer_cast(dev_ctx_.GetBlob(key_p)); + if (backward_p == nullptr) { + PADDLE_ENFORCE_NOT_NULL(bwd_w_pd_, platform::errors::Unavailable( + "Error: BWD_PD should be set when " + "getting BWD prim witk key: %s .", + key_p)); + backward_p = std::make_shared(*bwd_w_pd_); + dev_ctx_.SetBlob(key_p, backward_p); + } + return backward_p; + } + std::shared_ptr AcquireSrcMemory( const framework::Tensor* input) { const T* input_data = input->data(); @@ -116,6 +132,29 @@ class MKLDNNHandlerT { "@diff_src_mem_p"); } + // Buffer of given Tensor is used for oneDNN computation + std::shared_ptr AcquireDiffWeightsMemory( + framework::Tensor* diff_weights) { + PADDLE_ENFORCE_NOT_NULL( + bwd_w_pd_, + platform::errors::Unavailable( + "Error: BWD_W_PD should be set when getting BWD grad of weights.")); + T* ptr = diff_weights->mutable_data( + place_, bwd_w_pd_->diff_weights_desc().get_size()); + return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc(), ptr, + "@diff_wei_mem_p"); + } + + // Buffer is allocated by oneDNN to store computation results + std::shared_ptr AcquireDiffWeightsMemory(void) { + PADDLE_ENFORCE_NOT_NULL( + bwd_w_pd_, + platform::errors::Unavailable( + "Error: BWD_W_PD should be set when getting BWD grad of weights.")); + return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc(), + "@diff_wei_mem_p"); + } + protected: bool isCached() { const std::string key_pd = key_common_ + "@fwd_pd"; @@ -243,6 +282,27 @@ class MKLDNNHandlerT { } } + template + void AcquireBackwardWeightsPrimitiveDescriptorNonBlocking(Args&&... args) { + // fwd_pd_ is set during grad by calling + // AcquireForwardPrimitiveDescriptorNonBlocking + PADDLE_ENFORCE_NOT_NULL( + fwd_pd_, + platform::errors::Unavailable("Get MKLDNN Forward primitive %s failed.", + key_ + "@fwd_pd")); + const std::string key_pd = key_ + "@bwd_w_pd"; + bwd_w_pd_ = + std::static_pointer_cast( + dev_ctx_.GetBlob(key_pd)); + if (bwd_w_pd_ == nullptr) { + auto bwd_desc = + typename TBackward_params::desc(std::forward(args)...); + bwd_w_pd_ = std::make_shared( + bwd_desc, engine_, *fwd_pd_); + dev_ctx_.SetBlob(key_pd, bwd_w_pd_); + } + } + std::shared_ptr AcquireMemoryFromPrimitive( const std::string& suffix) { return std::static_pointer_cast( @@ -370,6 +430,7 @@ class MKLDNNHandlerT { std::string key_; std::shared_ptr fwd_pd_; std::shared_ptr bwd_pd_; + std::shared_ptr bwd_w_pd_; }; // TODO(grygielski) this class will be deleted later.