From 4cf499c0ab08abfd8f278e0eaaa18d02fb435cf4 Mon Sep 17 00:00:00 2001 From: bingyanghuang <33643817+bingyanghuang@users.noreply.github.com> Date: Fri, 18 Oct 2019 13:38:54 +0800 Subject: [PATCH] cherry-pick PR#20640 to release 1.6, test=release/1.6 (#20706) --- .../fluid/operators/mkldnn/conv_mkldnn_op.cc | 453 ++++++++++-------- paddle/fluid/platform/mkldnn_reuse.h | 50 ++ 2 files changed, 303 insertions(+), 200 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index a9fc17ce89..20f5ffa4d5 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -29,32 +29,34 @@ using mkldnn::stream; using platform::to_void_cast; using platform::GetMKLDNNFormat; -constexpr int same_scale_mask = 0; -constexpr int o_slice_mask = 1 << 0; // 1 -constexpr int g_slice_mask = 1 << 1; // 2 -constexpr int g_o_slice_mask = g_slice_mask | o_slice_mask; // 3 - -static int ComputeMask(bool is_multi_channel, int multi_channel_mask) { - return is_multi_channel ? multi_channel_mask : same_scale_mask; -} - -static int ComputeWeightsMask(int is_multi_channel, int g) { - int multi_channel_mask = g > 1 ? g_o_slice_mask : o_slice_mask; - return ComputeMask(is_multi_channel, multi_channel_mask); -} - -static int ComputeBiasMask(int is_multi_channel) { - return ComputeMask(is_multi_channel, o_slice_mask); -} - -inline void GetWeightsTz(std::vector& weights_tz, int groups) { // NOLINT +inline void GetWeightsTz(std::vector& weights_tz, int groups, // NOLINT + bool is_conv3d) { if (groups > 1) { - // if (is_conv3d) [o, i, dimension, h, w]->[g, o/g, i, dimension, h, w] - // else [o, i, h, w] -> [g, o/g, i, h, w] - weights_tz.push_back(0); - std::rotate(weights_tz.begin(), weights_tz.end() - 1, weights_tz.end()); - weights_tz[0] = groups; - weights_tz[1] = weights_tz[1] / groups; + if (is_conv3d) { + int output = weights_tz[0]; + int input = weights_tz[1]; + int dimension = weights_tz[2]; + int height = weights_tz[3]; + int width = weights_tz[4]; + weights_tz.resize(6); + weights_tz[0] = groups; + weights_tz[1] = output / groups; + weights_tz[2] = input; + weights_tz[3] = dimension; + weights_tz[4] = height; + weights_tz[5] = width; + } else { + int output = weights_tz[0]; + int input = weights_tz[1]; + int height = weights_tz[2]; + int width = weights_tz[3]; + weights_tz.resize(5); + weights_tz[0] = groups; + weights_tz[1] = output / groups; + weights_tz[2] = input; + weights_tz[3] = height; + weights_tz[4] = width; + } } } @@ -67,59 +69,28 @@ inline MKLDNNMemoryFormat GetWeightsFormat(MKLDNNMemoryFormat format, } } -static std::vector ComputeOutputShiftScale( - const float scale_out_data, const float scale_in_data, - const std::vector& scale_weights_data) { - int count = scale_weights_data.size(); - std::vector output_shift_scale(count); -#pragma omp parallel for - for (int i = 0; i < count; i++) { - if (scale_weights_data[i] == 0.0) { - output_shift_scale[i] = scale_out_data; - } else { - output_shift_scale[i] = - static_cast(static_cast(scale_out_data) / - (static_cast(scale_in_data) * - static_cast(scale_weights_data[i]))); - } - } - return output_shift_scale; -} - -static std::vector ComputeBiasScale( - const float scale_in_data, const std::vector& scale_weights_data) { - int count = scale_weights_data.size(); - std::vector scale_bias_data(count); -#pragma omp parallel for if (count > 1) - for (int i = 0; i < count; i++) { - scale_bias_data[i] = scale_in_data * scale_weights_data[i]; - } - return scale_bias_data; -} - static mkldnn::memory::data_type GetDstType(bool is_int8, bool force_fp32_output, std::string fuse_activation, bool fuse_residual_conn, const Tensor* residual_param) { auto dst_dt = mkldnn::memory::data_type::f32; // uint8_t, int8_t, float - if (is_int8 && !force_fp32_output) { + if (is_int8) { + dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6") + ? mkldnn::memory::data_type::u8 + : mkldnn::memory::data_type::s8; + if (force_fp32_output) { + dst_dt = mkldnn::memory::data_type::f32; + } if (fuse_residual_conn && residual_param) { - // when residual exists, dst_dt will follow the residual_param type, - // but output will to be set to u8 if relu exists auto residual_dt = framework::ToMKLDNNDataType(residual_param->type()); - dst_dt = residual_dt; - } else { - // when residual does not exist, if (b)relu exist s8 else s8 - dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6") - ? mkldnn::memory::data_type::u8 - : mkldnn::memory::data_type::s8; + if (dst_dt != residual_dt) dst_dt = residual_dt; } } return dst_dt; } -template +template class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { public: void Compute(const paddle::framework::ExecutionContext& ctx) const override { @@ -215,7 +186,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { auto src_tz = paddle::framework::vectorize(input->dims()); auto weights_tz = paddle::framework::vectorize(filter->dims()); int g = std::max(groups, 1); - GetWeightsTz(weights_tz, g); + GetWeightsTz(weights_tz, g, is_conv3d); auto dst_tz = paddle::framework::vectorize(output->dims()); // Get unique name for storing MKLDNN primitives @@ -297,8 +268,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { auto residual_param = ctx.Input("ResidualData"); auto residual_param_data = residual_param->data(); - PADDLE_ENFORCE( - residual_param_data != nullptr, + PADDLE_ENFORCE_NE( + residual_param_data, nullptr, "Provide data if you want MKLDNN conv+elementwise_add fusion"); PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(), "Output and elementwise parameter need to have the " @@ -358,7 +329,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { output->set_layout(DataLayout::kMKLDNN); output->set_format(GetMKLDNNFormat(*dst_memory_p)); } - template void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const { const bool is_test = ctx.Attr("is_test"); @@ -417,11 +387,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { bool force_fp32_output = ctx.Attr("force_fp32_output"); bool unsigned_output = (fuse_activation == "relu" || fuse_activation == "relu6"); - auto scale_in_data = ctx.Attr("Scale_in"); - auto scale_in_eltwise_data = ctx.Attr("Scale_in_eltwise"); - auto scale_weights_data = ctx.Attr>("Scale_weights"); - auto scale_out_data = - force_fp32_output ? 1.0f : ctx.Attr("Scale_out"); PADDLE_ENFORCE(!fuse_residual_conn || !force_fp32_output, "residual fusion does not support force output with fp32"); @@ -442,7 +407,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { auto src_tz = paddle::framework::vectorize(input->dims()); auto weights_tz = paddle::framework::vectorize(filter->dims()); int g = std::max(groups, 1); - GetWeightsTz(weights_tz, g); + + GetWeightsTz(weights_tz, g, is_conv3d); auto dst_tz = paddle::framework::vectorize(output->dims()); mkldnn::memory::data_type src_dt = @@ -451,143 +417,229 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::string key = platform::CreateKey( src_tz, src_dt, ctx.op().Input("Input") + ctx.op().Input("Filter")); + const std::string key_conv_pd = key + "@conv_pd"; + + bool need_s8_to_u8 = false; std::shared_ptr conv_p; std::shared_ptr src_memory_p; std::shared_ptr user_src_memory_p; + std::shared_ptr dst_memory_p; std::vector pipeline; std::shared_ptr conv_pd; - std::shared_ptr dst_memory_p, user_residual_memory_p; - - const float* filter_data = filter->data(); - bool is_multi_channel = scale_weights_data.size() > 1; - - auto output_shift_scale = ComputeOutputShiftScale( - scale_out_data, scale_in_data, scale_weights_data); - - float scale_residual = - fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f; - auto user_src_md = - platform::MKLDNNMemDesc({src_tz}, src_dt, input->format()); - auto user_weights_md = platform::MKLDNNMemDesc( - {weights_tz}, platform::MKLDNNGetDataType(), - ((g) == 1) ? mkldnn::memory::format::oihw - : mkldnn::memory::format::goihw); - - /* create memory descriptor for convolution without specified format - * ('any') which lets a primitive (convolution in this case) choose - * the memory format preferred for best performance - */ - std::string data_format = ctx.Attr("data_format"); - auto chosen_memory_format = - platform::data_format_to_memory_format(data_format); - - auto src_md = platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format); - auto weights_md = platform::MKLDNNMemDesc(weights_tz, memory::data_type::s8, - chosen_memory_format); - auto dst_md = platform::MKLDNNMemDesc( - dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); - - platform::ConvMKLDNNHandler handler(dev_ctx, mkldnn_engine, key); - auto propagation = is_test ? mkldnn::prop_kind::forward_scoring - : mkldnn::prop_kind::forward_training; - - std::vector bias_tz; - - if (bias) { - bias_tz = paddle::framework::vectorize(bias->dims()); - auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32, - mkldnn::memory::format::x); - conv_pd = handler.AcquireConvolutionPrimitiveDescriptor( - src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, - fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn, - propagation, output_shift_scale, scale_residual); - } else { - conv_pd = handler.AcquireConvolutionPrimitiveDescriptor( - src_md, weights_md, boost::none, dst_md, strides, paddings, - mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta, - fuse_residual_conn, propagation, output_shift_scale, scale_residual); + std::shared_ptr handler; + + // This is workaround for hacky implementation + // of conv int8 mkl-dnn. Once conv fp32 and conv int8 + // are merged/unified, this will disappear + std::string key_tid = ""; + if (platform::get_cur_mkldnn_session_id() == + platform::kMKLDNNSessionID_Default) { + key_tid = "-t:" + platform::ThreadIDasStr(); } - // create mkldnn memory from input tensors (data/weights) - user_src_memory_p = - handler.AcquireSrcMemory(user_src_md, to_void_cast(input_data)); - auto user_weights_memory_p = handler.AcquireWeightsMemory( - user_weights_md, to_void_cast(filter_data)); - - // create reorder primitive if the input format is not the preferred one - src_memory_p = - handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); - - std::shared_ptr weights_memory_p; + auto prim_key = key + key_tid + "@conv_p"; + auto dst_key = key + key_tid + "@dst_mem_p"; + auto src_key = key + key_tid + "@src_mem_p"; + auto user_src_key = key + key_tid + "@user_src_mem_p"; + auto src_reorder_key = key + key_tid + "@src_mem_preorder_p"; + auto residual_reorder_key = key + key_tid + "@residual_data_mem_preorder_p"; + + conv_p = std::static_pointer_cast( + dev_ctx.GetBlob(prim_key)); + + if (conv_p == nullptr || !is_test) { + const K* filter_data = filter->data(); + auto scale_in_data = ctx.Attr("Scale_in"); + auto scale_in_eltwise_data = ctx.Attr("Scale_in_eltwise"); + auto scale_weights_data = ctx.Attr>("Scale_weights"); + auto scale_out_data = + force_fp32_output ? 1.0f : ctx.Attr("Scale_out"); + float sum_scale = + fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f; + + bool is_multi_channel = scale_weights_data.size() > 1; + + int count = is_multi_channel ? (g > 1 ? (weights_tz)[1] * (weights_tz)[0] + : (weights_tz)[0]) + : 1; + std::vector output_shift_scale(count); +#pragma omp parallel for if (count > 1) + for (int i = 0; i < count; i++) { + if (scale_weights_data[i] == 0.0) + output_shift_scale[i] = + scale_out_data; // weights data will contain 0 + // in some models, then weights + // scale couldn't be calculated + else + output_shift_scale[i] = + static_cast(static_cast(scale_out_data) / + (static_cast(scale_in_data) * + static_cast(scale_weights_data[i]))); + } - int mask_reorder = ComputeWeightsMask(is_multi_channel, g); + auto user_src_md = + platform::MKLDNNMemDesc({src_tz}, src_dt, input->format()); + auto user_weights_md = platform::MKLDNNMemDesc( + {weights_tz}, platform::MKLDNNGetDataType(), + ((g) == 1) ? MKLDNNMemoryFormat::oihw : MKLDNNMemoryFormat::goihw); + + /* create memory descriptor for convolution without specified format + * ('any') which lets a primitive (convolution in this case) choose + * the memory format preferred for best performance + */ + std::string data_format = ctx.Attr("data_format"); + auto chosen_memory_format = + platform::data_format_to_memory_format(data_format); + + std::vector bias_tz; + + auto src_md = + platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format); + auto weights_md = platform::MKLDNNMemDesc( + weights_tz, memory::data_type::s8, chosen_memory_format); + auto dst_md = platform::MKLDNNMemDesc( + dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); + + handler.reset( + new platform::ConvMKLDNNHandler(dev_ctx, mkldnn_engine, key)); + // create a conv primitive descriptor and save it for usage in backward + auto propagation = is_test ? mkldnn::prop_kind::forward_scoring + : mkldnn::prop_kind::forward_training; - weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( - user_weights_memory_p, pipeline, is_test, true, scale_weights_data, - mask_reorder); + if (bias) { + bias_tz = paddle::framework::vectorize(bias->dims()); + auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32, + MKLDNNMemoryFormat::x); + conv_pd = handler->AcquireConvolutionPrimitiveDescriptor( + src_md, weights_md, bias_md, dst_md, strides, paddings, + mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta, + fuse_residual_conn, propagation, output_shift_scale, sum_scale); + } else { + conv_pd = handler->AcquireConvolutionPrimitiveDescriptor( + src_md, weights_md, boost::none, dst_md, strides, paddings, + mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta, + fuse_residual_conn, propagation, output_shift_scale, sum_scale); + } - if (fuse_residual_conn) { - auto residual_param = ctx.Input("ResidualData"); - auto residual_param_data = residual_param->data(); - PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(), - "Output and elementwise parameter need to have the " - "same dimension sizes"); - auto residual_dt = - paddle::framework::ToMKLDNNDataType(residual_param->type()); - if (residual_param->format() != handler.GetDstFormat()) { - auto residual_data_tz = - paddle::framework::vectorize(residual_param->dims()); - auto user_residual_md = platform::MKLDNNMemDesc( - residual_data_tz, residual_dt, residual_param->format()); + // create mkldnn memory from input tensors (data/weights) + user_src_memory_p = + handler->AcquireSrcMemory(user_src_md, to_void_cast(input_data)); + auto user_weights_memory_p = handler->AcquireWeightsMemory( + user_weights_md, to_void_cast(filter_data)); + + // create reorder primitive if the input format is not the preferred one + src_memory_p = + handler->AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); + + std::shared_ptr weights_memory_p; + int mask_reorder = + is_multi_channel ? ((g != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0; + weights_memory_p = handler->AcquireWeightsMemoryFromPrimitive( + user_weights_memory_p, pipeline, is_test, true, scale_weights_data, + mask_reorder); + + if (fuse_residual_conn) { + auto residual_param = ctx.Input("ResidualData"); + PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(), + "Output and elementwise parameter need to have the " + "same dimension sizes"); + auto residual_dt = + paddle::framework::ToMKLDNNDataType(residual_param->type()); + if (residual_param->format() != handler->GetDstFormat()) { + auto residual_data_tz = + paddle::framework::vectorize(residual_param->dims()); + auto user_residual_md = platform::MKLDNNMemDesc( + residual_data_tz, residual_dt, residual_param->format()); + dst_memory_p = platform::SetDstMemory( + ctx, output, residual_param, user_residual_md, handler, + &pipeline); + } else { + output->ShareDataWith(*residual_param); + dst_memory_p = platform::SetDstMemory(ctx, output, handler); + } + need_s8_to_u8 = + (platform::MKLDNNGetDataType() == memory::data_type::s8) && + unsigned_output; + } else { + dst_memory_p = platform::SetDstMemory(ctx, output, handler); + } - user_residual_memory_p = handler.AcquireResidualDataMemory( - user_residual_md, to_void_cast(residual_param_data)); + // create convolution op primitive + auto scale_bias_key = key + "@scale_bias"; + if (bias) { + const K* bias_data = bias->data(); + auto user_bias_md = platform::MKLDNNMemDesc( + {bias_tz}, platform::MKLDNNGetDataType(), MKLDNNMemoryFormat::x); + auto user_bias_memory_p = handler->AcquireBiasMemory( + user_bias_md, to_void_cast(bias_data)); + std::shared_ptr bias_memory_p; + int mask_reorder = is_multi_channel ? 1 << 0 : 1; + int count = + is_multi_channel + ? (g > 1 ? (weights_tz)[1] * (weights_tz)[0] : (weights_tz)[0]) + : 1; + std::vector scale_bias_data(count); +#pragma omp parallel for if (count > 1) + for (int i = 0; i < count; i++) { + scale_bias_data[i] = scale_in_data * scale_weights_data[i]; + } + bias_memory_p = handler->AcquireBiasMemoryFromPrimitive( + user_bias_memory_p, pipeline, is_test, true, scale_bias_data, + mask_reorder); + conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p, + bias_memory_p, dst_memory_p); + } else { + conv_p = handler->AcquireConvolution(src_memory_p, weights_memory_p, + dst_memory_p); + } + // push primitive to stream and wait until it's executed + pipeline.push_back(*conv_p); + } else { + auto src_memory_reorder_p = std::static_pointer_cast( + dev_ctx.GetBlob(src_reorder_key)); + src_memory_p = + std::static_pointer_cast(dev_ctx.GetBlob(src_key)); + if (src_memory_reorder_p) { + user_src_memory_p = std::static_pointer_cast( + dev_ctx.GetBlob(user_src_key)); + user_src_memory_p->set_data_handle(to_void_cast(input_data)); + } else if (src_memory_p) { + src_memory_p->set_data_handle(to_void_cast(input_data)); + } - T_out* output_data = output->mutable_data(ctx.GetPlace()); - dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory( - user_residual_memory_p, to_void_cast(output_data), pipeline); + dst_memory_p = + std::static_pointer_cast(dev_ctx.GetBlob(dst_key)); + conv_pd = + std::static_pointer_cast( + dev_ctx.GetBlob(key_conv_pd)); + if (conv_pd) { + handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx, + mkldnn_engine, key)); + } - } else { + if (fuse_residual_conn) { + auto residual_param = ctx.Input("ResidualData"); output->ShareDataWith(*residual_param); - auto output_data = output->mutable_data(ctx.GetPlace()); - dst_memory_p = handler.AcquireDstMemoryFromPrimitive( - to_void_cast(output_data)); + need_s8_to_u8 = + (platform::MKLDNNGetDataType() == memory::data_type::s8) && + unsigned_output; } - } else { - T_out* output_data = output->mutable_data( - ctx.GetPlace(), handler.GetDstMemorySize()); - dst_memory_p = handler.AcquireDstMemoryFromPrimitive( - to_void_cast(output_data)); - } + platform::SetDstMemoryHandler(ctx, output, handler, dst_memory_p); - // create convolution op primitive - if (bias) { - const float* bias_data = bias->data(); - auto user_bias_md = platform::MKLDNNMemDesc( - {bias_tz}, platform::MKLDNNGetDataType(), memory::format::x); - auto user_bias_memory_p = handler.AcquireBiasMemory( - user_bias_md, to_void_cast(bias_data)); - std::shared_ptr bias_memory_p; - - auto scale_bias_data = - ComputeBiasScale(scale_in_data, scale_weights_data); - int mask_bias_reorder = ComputeBiasMask(is_multi_channel); - bias_memory_p = handler.AcquireBiasMemoryFromPrimitive( - user_bias_memory_p, pipeline, is_test, true, scale_bias_data, - mask_bias_reorder); - conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p, - bias_memory_p, dst_memory_p); - } else { - conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p, - dst_memory_p); - } - // push primitive to stream and wait until it's executed - pipeline.push_back(*conv_p); + if (src_memory_reorder_p) { + pipeline.push_back(*src_memory_reorder_p); + } + auto residual_reorder_p = std::static_pointer_cast( + dev_ctx.GetBlob(residual_reorder_key)); + if (residual_reorder_p) { + pipeline.push_back(*residual_reorder_p); + } + pipeline.push_back(*conv_p); + } // push primitive to stream and wait until it's executed stream(stream::kind::eager).submit(pipeline).wait(); - if (platform::MKLDNNGetDataType() == memory::data_type::s8 && - unsigned_output) { + if (need_s8_to_u8) { output->mutable_data(ctx.GetPlace()); } output->set_layout(DataLayout::kMKLDNN); @@ -649,7 +701,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { auto src_tz = paddle::framework::vectorize(input->dims()); auto weights_tz = paddle::framework::vectorize(filter->dims()); int g = std::max(groups, 1); - GetWeightsTz(weights_tz, g); + GetWeightsTz(weights_tz, g, is_conv3d); auto dst_tz = paddle::framework::vectorize(output_grad->dims()); auto src_format = input->format(); MKLDNNMemoryFormat weights_format = @@ -704,8 +756,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { auto conv_pd = std::static_pointer_cast( dev_ctx.GetBlob(key_conv_pd)); - PADDLE_ENFORCE(conv_pd != nullptr, - "Fail to find conv_pd in device context"); + PADDLE_ENFORCE_NE(conv_pd, nullptr, + "Fail to find conv_pd in device context"); // create backward convolution weights primitive descriptor auto conv_bwd_weights_desc = mkldnn::convolution_backward_weights::desc( @@ -786,6 +838,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { stream(stream::kind::eager).submit(pipeline).wait(); } }; + } // namespace operators } // namespace paddle @@ -794,17 +847,17 @@ namespace ops = paddle::operators; REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, ::paddle::platform::CPUPlace, FP32, ops::kConvMKLDNNFP32, - ops::ConvMKLDNNOpKernel); + ops::ConvMKLDNNOpKernel); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, ::paddle::platform::CPUPlace, U8, ops::kConvMKLDNNINT8, - ops::ConvMKLDNNOpKernel); + ops::ConvMKLDNNOpKernel); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, ::paddle::platform::CPUPlace, S8, ops::kConvMKLDNNINT8, - ops::ConvMKLDNNOpKernel); + ops::ConvMKLDNNOpKernel); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, FP32, @@ -814,7 +867,7 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN, REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, MKLDNN, ::paddle::platform::CPUPlace, FP32, ops::kConvMKLDNNFP32, - ops::ConvMKLDNNOpKernel); + ops::ConvMKLDNNOpKernel); REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d_grad, MKLDNN, ::paddle::platform::CPUPlace, FP32, diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index aa0b3d7d33..a18228a689 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -816,6 +816,15 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { mkldnn::engine engine, const std::string& base_key) : platform::MKLDNNHandler(dev_ctx, engine, base_key) {} + // TODO(jczaja): remove after conv int8 is adapted + ConvMKLDNNTemplateHandler( + std::shared_ptr conv_pd, + const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine, + const std::string& base_key) + : platform::MKLDNNHandler(dev_ctx, engine, base_key) { + conv_pd_ = conv_pd; + } + ConvMKLDNNTemplateHandler( std::shared_ptr conv_pd, std::shared_ptr @@ -1136,6 +1145,47 @@ using ConvTransposeMKLDNNHandler = mkldnn::deconvolution_backward_data, mkldnn::deconvolution_backward_weights>; +template +static std::shared_ptr SetDstMemory( + const framework::ExecutionContext& ctx, framework::Tensor* output, + const std::shared_ptr& handler) { + T* output_data = + output->mutable_data(ctx.GetPlace(), handler->GetDstMemorySize()); + std::shared_ptr dst_memory_p = + handler->AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); + return dst_memory_p; +} + +template +static std::shared_ptr SetDstMemory( + const framework::ExecutionContext& ctx, framework::Tensor* output, + const framework::Tensor* residual_param, + const mkldnn::memory::desc& user_residual_md, + const std::shared_ptr& handler, + std::vector* pipeline) { + const T* residual_param_data = residual_param->data(); + PADDLE_ENFORCE(residual_param_data != nullptr, + "Provide data if you want MKLDNN conv+elementwise_add fusion"); + std::shared_ptr user_residual_memory_p = + handler->AcquireResidualDataMemory(user_residual_md, + to_void_cast(residual_param_data)); + T* output_data = output->mutable_data(ctx.GetPlace()); + std::shared_ptr dst_memory_p = + handler->AcquireDstMemoryFromResidualDataMemory( + user_residual_memory_p, to_void_cast(output_data), *pipeline); + return dst_memory_p; +} + +template +static void SetDstMemoryHandler( + const framework::ExecutionContext& ctx, framework::Tensor* output, + const std::shared_ptr& handler, + std::shared_ptr dst_memory_p) { + T* output_data = + output->mutable_data(ctx.GetPlace(), handler->GetDstMemorySize()); + dst_memory_p->set_data_handle(to_void_cast(output_data)); +} + template static void SetDstMemoryQuantized( const framework::ExecutionContext& ctx, framework::Tensor* output, -- GitLab