diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index 65092e059f4afe0bff98cce3e63a520e2d736a6c..7b790a6081ed73a038238db7c66b06638fde4075 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -144,12 +144,6 @@ class ConvMKLDNNHandlerT
               bias->dims().size()));
     }
 
-    const std::string fuse_activation =
-        ctx.Attr<std::string>("fuse_activation");
-    const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
-    const float fuse_beta = ctx.Attr<float>("fuse_beta");
-    const bool fuse_residual_conn =
-        ctx.Attr<bool>("fuse_residual_connection");
     const int groups = ctx.Attr<int>("groups");
     const std::string padding_algorithm =
         ctx.Attr<std::string>("padding_algorithm");
@@ -221,24 +215,7 @@ class ConvMKLDNNHandlerT
     const auto fwd_prop_kind = is_test ? dnnl::prop_kind::forward_inference
                                        : dnnl::prop_kind::forward_training;
 
-    float sum_scale = 1.0f;
-    float activation_scale = 1.0f;
-    std::vector<float> output_shift_scale;
-    if (platform::is_int8<T>()) {
-      if (ctx.HasAttr("Sum_scale")) {
-        sum_scale = ctx.Attr<float>("Sum_scale");
-        activation_scale = ctx.Attr<float>("Activation_scale");
-        output_shift_scale =
-            ctx.Attr<std::vector<float>>("Output_shift_scale");
-      } else {
-        std::tie(sum_scale, output_shift_scale, activation_scale) =
-            get_int8_scales(ctx);
-      }
-    }
-
-    const dnnl::primitive_attr conv_attr = CreatePostOps(
-        fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn,
-        output_shift_scale, sum_scale, activation_scale);  // for INT8 only!
+    const dnnl::primitive_attr conv_attr = CreateConvAttrs(ctx);
 
     if (bias) {
       auto bias_tz = phi::vectorize(bias->dims());
@@ -460,12 +437,13 @@ class ConvMKLDNNHandlerT
     auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
     bool is_multi_channel = scale_weights_data.size() > 1;
     bool has_activation = !ctx.Attr<std::string>("fuse_activation").empty();
-    float activation_scale = force_fp32_output ? 1.0f
-                             : has_activation  ? ctx.Attr<float>("Scale_out")
-                                               : 1.0f;
-    auto scale_out_data = force_fp32_output ? 1.0f
-                          : has_activation  ? 1.0f
-                                            : ctx.Attr<float>("Scale_out");
+    float activation_scale = (!force_fp32_output && has_activation)
+                                 ? ctx.Attr<float>("Scale_out")
+                                 : 1.0f;
+
+    float scale_out_data = (force_fp32_output || has_activation)
+                               ? 1.0f
+                               : ctx.Attr<float>("Scale_out");
     float sum_scale =
         fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
     int count =
@@ -490,15 +468,33 @@ class ConvMKLDNNHandlerT
     return std::make_tuple(sum_scale, output_shift_scale, activation_scale);
   }
 
-  dnnl::primitive_attr CreatePostOps(
-      std::string fuse_activation, float fuse_alpha, float fuse_beta,
-      bool fuse_residual_conn, const std::vector<float> output_shift_scale = {},
-      float sum_scale = 1.0f, float activation_scale = 1.0f) {
+  dnnl::primitive_attr CreateConvAttrs(const framework::ExecutionContext& ctx) {
     dnnl::primitive_attr conv_attr;
     dnnl::post_ops post_operations;
-    if (output_shift_scale.size() > 0) {
-      int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
-      conv_attr.set_output_scales(mask, output_shift_scale);
+
+    const std::string fuse_activation =
+        ctx.Attr<std::string>("fuse_activation");
+    const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
+    const float fuse_beta = ctx.Attr<float>("fuse_beta");
+    const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
+
+    float sum_scale = 1.0f;
+    float activation_scale = 1.0f;
+    std::vector<float> output_shift_scale;
+    if (platform::is_int8<T>()) {
+      if (ctx.HasAttr("Sum_scale")) {
+        sum_scale = ctx.Attr<float>("Sum_scale");
+        activation_scale = ctx.Attr<float>("Activation_scale");
+        output_shift_scale = ctx.Attr<std::vector<float>>("Output_shift_scale");
+      } else {
+        std::tie(sum_scale, output_shift_scale, activation_scale) =
+            get_int8_scales(ctx);
+      }
+
+      if (output_shift_scale.size() > 0) {
+        int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
+        conv_attr.set_output_scales(mask, output_shift_scale);
+      }
     }
 
     // Fusion with Elementwise layer relies on adding a sum post-operation with
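
Note for reviewers: the hunks above only move the scale/post-op logic into one place; the dnnl::primitive_attr that reaches the primitive descriptor is unchanged. For reference, here is a minimal standalone sketch of the oneDNN 2.x pattern that CreateConvAttrs now owns end to end. The helper name and its parameter list are illustrative, not Paddle code, and only the relu/leaky_relu activations are shown:

#include <string>
#include <vector>

#include "dnnl.hpp"

// Illustrative helper mirroring the fused-conv attribute pattern. Every
// parameter stands in for a Paddle op attribute of the same name.
dnnl::primitive_attr make_conv_attrs(
    const std::vector<float>& output_shift_scale,  // INT8 only, may be empty
    float sum_scale, float activation_scale, bool fuse_residual_conn,
    const std::string& fuse_activation, float fuse_alpha, float fuse_beta) {
  dnnl::primitive_attr attr;
  dnnl::post_ops ops;

  // Per-output-channel scales use mask 1 << 1 (dim 1 of OIHW weights);
  // a single scale uses mask 0, i.e. one value for the whole tensor.
  if (!output_shift_scale.empty()) {
    int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
    attr.set_output_scales(mask, output_shift_scale);
  }

  // Residual connection: accumulate into the existing dst buffer.
  if (fuse_residual_conn) {
    ops.append_sum(sum_scale);
  }

  // Activation fused as an eltwise post-op; fuse_alpha is the negative
  // slope for leaky_relu and is ignored by plain relu.
  if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
    ops.append_eltwise(activation_scale, dnnl::algorithm::eltwise_relu,
                       fuse_alpha, fuse_beta);
  }

  attr.set_post_ops(ops);
  return attr;
}
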
diff --git a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
index 99b8d7d6ae385b3b3b88940a534d2f895d52eaa7..615c7299bed03b6fb87a0969a352a2d3e55c5c9e 100644
--- a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
@@ -139,10 +139,6 @@ class ConvTransposeMKLDNNHandlerT
      * the memory format preferred for best performance
      */
     const auto chosen_memory_format = MKLDNNMemoryFormat::any;
-    const std::string fuse_activation =
-        ctx.Attr<std::string>("fuse_activation");
-    const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
-    const float fuse_beta = ctx.Attr<float>("fuse_beta");
 
     auto data_type = dnnl::memory::data_type::f32;
     if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
@@ -156,8 +152,7 @@ class ConvTransposeMKLDNNHandlerT
     const auto dst_md = platform::MKLDNNMemDesc(
         dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
 
-    const dnnl::primitive_attr conv_trans_attr =
-        CreatePostOps(fuse_activation, fuse_alpha, fuse_beta);
+    const dnnl::primitive_attr conv_trans_attr = CreateConvAttrs(ctx);
     auto fwd_prop_kind = is_test_ ? dnnl::prop_kind::forward_inference
                                   : dnnl::prop_kind::forward_training;
     if (bias) {
@@ -176,12 +171,15 @@ class ConvTransposeMKLDNNHandlerT
     }
   }
 
-  dnnl::primitive_attr CreatePostOps(const std::string& fuse_activation,
-                                     const float& fuse_alpha,
-                                     const float& fuse_beta) {
+  dnnl::primitive_attr CreateConvAttrs(const framework::ExecutionContext& ctx) {
     dnnl::primitive_attr conv_attr;
     dnnl::post_ops post_operations;
 
+    const std::string fuse_activation =
+        ctx.Attr<std::string>("fuse_activation");
+    const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
+    const float fuse_beta = ctx.Attr<float>("fuse_beta");
+
     // Fusion with ReLU layer is executed through the PostOps feature. Create a
     // PostOps object and configure it to execute an eltwise relu operation.
     if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
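
The transpose handler keeps the same PostOps mechanism; after this change it simply reads the fuse_* attributes itself. For readers new to the feature, a small sketch of the activation-name-to-eltwise mapping implied by the branch that closes the hunk above (both helper names are hypothetical; only the two activations checked there are handled):

#include <stdexcept>
#include <string>

#include "dnnl.hpp"

// Hypothetical mapping from a fuse_activation value to a oneDNN algorithm.
// relu and leaky_relu share eltwise_relu; leaky_relu carries its negative
// slope in the alpha argument of append_eltwise.
dnnl::algorithm to_eltwise_algorithm(const std::string& fuse_activation) {
  if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
    return dnnl::algorithm::eltwise_relu;
  }
  throw std::invalid_argument("unsupported fuse_activation: " +
                              fuse_activation);
}

// Appends the activation as an eltwise post-op (oneDNN 2.x signature:
// scale, algorithm, alpha, beta).
void append_activation(dnnl::post_ops& ops, const std::string& fuse_activation,
                       float fuse_alpha, float fuse_beta) {
  if (!fuse_activation.empty()) {
    ops.append_eltwise(1.0f, to_eltwise_algorithm(fuse_activation), fuse_alpha,
                       fuse_beta);
  }
}
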
diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
index 5cbcad5d965a4760883c690bb05c0e7537d77313..590ffe4d0d41b63069a13f3f6ed44afd5a79913e 100644
--- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
@@ -201,7 +201,7 @@ class FCPrimitiveFactory {
         CreateMemDescriptor(weight_dims, MKLDNNMemoryFormat::any);
     auto bias_desc = CreateMemDescriptor(bias, MKLDNNMemoryFormat::x);
     auto dst_desc = CreateMemDescriptor(output, MKLDNNMemoryFormat::any);
-    const auto attrs = CreatePostOps(ctx);
+    const auto attrs = CreateFCAttrs(ctx);
     return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc,
                             attrs);
   }
@@ -230,7 +230,7 @@ class FCPrimitiveFactory {
     auto dst_dims = {input_dims[0] * input_dims[1], weight_dims[0]};
     auto dst_desc = CreateMemDescriptor(dst_dims, MKLDNNMemoryFormat::any);
-    const auto attrs = CreatePostOps(ctx);
+    const auto attrs = CreateFCAttrs(ctx);
     return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc,
                             attrs);
   }
@@ -255,7 +255,7 @@ class FCPrimitiveFactory {
     auto weights_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::any);
     auto bias_desc = CreateMemDescriptor(bias, MKLDNNMemoryFormat::x);
     auto dst_desc = CreateMemDescriptor(output, MKLDNNMemoryFormat::any);
-    const auto attrs = CreatePostOps(ctx);
+    const auto attrs = CreateFCAttrs(ctx);
     return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc,
                             attrs);
   }
@@ -455,8 +455,7 @@ class FCPrimitiveFactory {
     bias_ = ReorderWithScale(bias_, fc_prim_desc.bias_desc(), bias_scales);
   }
 
-  // Fuse relu into FC with activation type attribute has been set to 'relu'
-  dnnl::primitive_attr CreatePostOps(const ExecutionContext& ctx) {
+  dnnl::primitive_attr CreateFCAttrs(const ExecutionContext& ctx) {
     dnnl::primitive_attr attributes;
     dnnl::post_ops post_operations;
 
@@ -465,8 +464,8 @@ class FCPrimitiveFactory {
     std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx);
     int mask = CreateMask(1, output_shift_scale.size() > 1);
     attributes.set_output_scales(mask, output_shift_scale);
-    float sum_scale = 1.0f;
 
+    float sum_scale = 1.0f;
     if (ctx.HasAttr("fuse_residual_connection") &&
         ctx.Attr<bool>("fuse_residual_connection")) {
       post_operations.append_sum(sum_scale);
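
On the mask passed to set_output_scales above: bit d of the mask tells oneDNN that the scales vector holds one value per element of logical dimension d, so CreateMask(1, ...) selects per-output-channel scales whenever more than one scale is present. A sketch of that logic plus a usage example follows; the create_mask body is an assumption inferred from the call site, not copied from fc_mkldnn_op.cc:

#include <vector>

#include "dnnl.hpp"

// Sketch of the logic behind CreateMask(1, output_shift_scale.size() > 1):
// setting bit `slice_dimension` means one scale per element of that logical
// dimension (dim 1 = output channels of the FC weights); mask 0 means a
// single common scale for the whole tensor.
int create_mask(int slice_dimension, bool is_multi_channel) {
  return is_multi_channel ? 1 << slice_dimension : 0;
}

int main() {
  dnnl::primitive_attr attributes;
  std::vector<float> output_shift_scale = {0.5f, 0.25f};  // per-channel
  attributes.set_output_scales(
      create_mask(1, output_shift_scale.size() > 1), output_shift_scale);
  return 0;
}
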
diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
index 8921db6cbcef98b7fcf325a3a8e9c61e0e67df74..12867a482c79fee6265d416f1e11d80936caa960 100644
--- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
@@ -147,16 +147,10 @@ class MatMulMKLDNNHandler
     this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
   }
   // Constructor for FWD MatMul
-  MatMulMKLDNNHandler(const dnnl::engine engine, const ExecutionContext& ctx,
-                      float scale)
+  MatMulMKLDNNHandler(const dnnl::engine engine, const ExecutionContext& ctx)
       : paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul>(
             engine, ctx.GetPlace()) {
-    dnnl::primitive_attr attr;
-    float scale_out = ComputeOutputScale(ctx);
-    if (scale_out != 1.0f) {
-      constexpr unsigned tensor_wide_scale = 0;
-      attr.set_output_scales(tensor_wide_scale, {scale_out});
-    }
+    const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs(ctx);
     auto matmul_dims_ = GetMatmulDims(ctx);
     auto x_md = memory::desc(matmul_dims_.x_dims, MKLDNNGetDataType<XT>(),
@@ -165,7 +159,7 @@ class MatMulMKLDNNHandler
                              matmul_dims_.y_strides);
     auto out_md = memory::desc(matmul_dims_.out_dims, MKLDNNGetDataType<OT>(),
                                matmul_dims_.out_strides);
-    this->AcquireForwardPrimitiveDescriptor(attr, x_md, y_md, out_md);
+    this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
   }
 
   std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
@@ -429,6 +423,19 @@ class MatMulMKLDNNHandler
     return std::make_tuple(x_offset_, y_offset_, out_offset_);
   }
 
+  dnnl::primitive_attr CreateMatmulAttrs(const ExecutionContext& ctx) {
+    dnnl::primitive_attr matmul_attrs;
+    dnnl::post_ops post_operations;
+
+    float scale_out = ComputeOutputScale(ctx);
+    if (scale_out != 1.0f) {
+      matmul_attrs.set_output_scales(0, {scale_out});
+    }
+
+    matmul_attrs.set_post_ops(post_operations);
+    return matmul_attrs;
+  }
+
  private:
   uint32_t x_offset_;
   uint32_t y_offset_;
@@ -499,23 +506,19 @@ static void ExecuteMatMul(const ExecutionContext& ctx) {
   auto* x = ctx.Input<Tensor>("X");
   auto* y = ctx.Input<Tensor>("Y");
   auto* out = ctx.Output<Tensor>("Out");
-  float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
   const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+  const auto& onednn_engine = dev_ctx.GetEngine();
 
   if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
-    MatMulMKLDNNHandler<float, float, float>(dev_ctx.GetEngine(), ctx, alpha)
-        .Execute(x, y, out);
+    MatMulMKLDNNHandler<float, float, float>(onednn_engine, ctx)
+        .Execute(x, y, out);
   } else if (is_bfloat16) {
-    MatMulMKLDNNHandler<paddle::platform::bfloat16>(dev_ctx.GetEngine(),
-                                                    ctx, alpha)
+    MatMulMKLDNNHandler<paddle::platform::bfloat16>(onednn_engine, ctx)
        .Execute(x, y, out);
   } else if (fuse_relu) {
-    MatMulMKLDNNHandler<int8_t, int8_t, uint8_t>(dev_ctx.GetEngine(), ctx, alpha)
-        .Execute(x, y, out);
+    MatMulMKLDNNHandler<int8_t, int8_t, uint8_t>(onednn_engine, ctx)
+        .Execute(x, y, out);
   } else {
-    MatMulMKLDNNHandler<int8_t, int8_t, int8_t>(dev_ctx.GetEngine(), ctx, alpha)
-        .Execute(x, y, out);
+    MatMulMKLDNNHandler<int8_t, int8_t, int8_t>(onednn_engine, ctx)
+        .Execute(x, y, out);
   }
 }
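
Dropping the `alpha` constructor argument loses nothing: ComputeOutputScale(ctx) reads the op's attributes inside CreateMatmulAttrs and folds the result into a tensor-wide output scale, where mask 0 means one scale for the entire output. A self-contained sketch of that attribute on an f32 matmul follows (oneDNN 2.x API; the shapes and the 0.125f alpha are illustrative):

#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);

  // 2x3 * 3x4 f32 matmul with a tensor-wide output scale, mirroring
  // CreateMatmulAttrs: mask 0 applies the single scale to every element.
  dnnl::memory::desc x_md({2, 3}, dnnl::memory::data_type::f32,
                          dnnl::memory::format_tag::ab);
  dnnl::memory::desc y_md({3, 4}, dnnl::memory::data_type::f32,
                          dnnl::memory::format_tag::ab);
  dnnl::memory::desc out_md({2, 4}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::ab);

  float alpha = 0.125f;  // stands in for the matmul op's "alpha" attribute
  dnnl::primitive_attr attr;
  if (alpha != 1.0f) attr.set_output_scales(0, {alpha});

  auto matmul_pd = dnnl::matmul::primitive_desc(
      dnnl::matmul::desc(x_md, y_md, out_md), attr, eng);
  auto matmul_prim = dnnl::matmul(matmul_pd);
  (void)matmul_prim;  // execution with real buffers omitted for brevity
  return 0;
}
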