From f3c14762ed64d1acb2a8ca8a2d05e1441d45f33b Mon Sep 17 00:00:00 2001 From: "joanna.wozna.intel" Date: Fri, 5 Aug 2022 11:18:10 +0200 Subject: [PATCH] Add int8 support for matmulV2 (#44908) --- .../operators/mkldnn/matmul_v2_mkldnn_op.cc | 328 +++++++++--------- .../fluid/operators/mkldnn/mul_mkldnn_op.cc | 20 +- paddle/fluid/platform/mkldnn_reuse.h | 70 +++- 3 files changed, 231 insertions(+), 187 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc index 02632673b9..2f9fa210e2 100644 --- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc @@ -659,7 +659,7 @@ float ComputeOutputScale(const ExecutionContext &ctx) { return alpha * scale_out / (scale_x * scale_y); } -template +template void ExecuteMatMulV2(const ExecutionContext &ctx, const MKLDNNDeviceContext &dev_ctx, const dnnl::engine onednn_engine, @@ -675,16 +675,16 @@ void ExecuteMatMulV2(const ExecutionContext &ctx, int execution_number = 0) { std::vector x_strides_override = GetInputStrides(ctx, "X"); std::vector y_strides_override = GetInputStrides(ctx, "Y"); - MatMulV2MKLDNNHandler handler(ctx, - onednn_engine, - ctx.GetPlace(), - x_dims, - trans_x, - y_dims, - trans_y, - IsOutputFused(ctx), - x_strides_override, - y_strides_override); + MatMulV2MKLDNNHandler handler(ctx, + onednn_engine, + ctx.GetPlace(), + x_dims, + trans_x, + y_dims, + trans_y, + IsOutputFused(ctx), + x_strides_override, + y_strides_override); const auto src_memory_p = handler.AcquireSrcMemory(x); const auto weights_memory_p = handler.AcquireWeightsMemory(y); @@ -707,17 +707,41 @@ void ExecuteMatMulV2(const ExecutionContext &ctx, auto &astream = MKLDNNDeviceContext::tls().get_stream(); matmul_p->execute(astream, matmul_args); astream.wait(); - - auto format = paddle::platform::MKLDNNFormatForSize( - out->dims().size(), dnnl::memory::format_tag::nchw); - out->set_layout(paddle::framework::DataLayout::kMKLDNN); + auto format = + MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw); out->set_format(format); + out->set_layout(DataLayout::kMKLDNN); } template class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel { public: - void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); } + void Compute(const ExecutionContext &ctx) const override { + if (ctx.HasAttr("head_number")) { + PADDLE_ENFORCE_EQ( + ctx.Attr("head_number"), + 1, + paddle::platform::errors::Unimplemented( + "oneDNN matmul doesn't support multiple heads. Expected " + "head_number=1. But received `head_number` is %d", + ctx.Attr("head_number"))); + } + constexpr bool is_int8 = IsInt8(); + constexpr bool is_bfloat16 = IsBfloat16(); + const bool force_fp32_output = ctx.HasAttr("force_fp32_output") + ? ctx.Attr("force_fp32_output") + : false; + constexpr bool fuse_relu = false; // TODO(intel): Enable eltwise fuses + if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) { + RunKernel(ctx); + } else if (is_bfloat16) { + RunKernel(ctx); + } else if (fuse_relu) { + RunKernel(ctx); + } else { + RunKernel(ctx); + } + } private: void CalculateMatrixDims(const ExecutionContext &ctx, @@ -768,6 +792,7 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel { } } + template void RunKernel(const ExecutionContext &ctx) const { const auto &dev_ctx = ctx.template device_context(); const auto &onednn_engine = dev_ctx.GetEngine(); @@ -793,18 +818,18 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel { CalculateMatrixDims( ctx, x_dims, y_dims, &x_bd_dims, &y_bd_dims, &out_dims, out); - ExecuteMatMulV2(ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - x, - x_bd_dims, - trans_x, - y, - y_bd_dims, - trans_y, - out, - out_dims); + ExecuteMatMulV2(ctx, + dev_ctx, + onednn_engine, + ctx.GetPlace(), + x, + x_bd_dims, + trans_x, + y, + y_bd_dims, + trans_y, + out, + out_dims); } }; @@ -939,113 +964,113 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel { ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims); if (trans_x && trans_y) { - ExecuteMatMulV2(ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - y, - y_dims, - true, - dout, - dout_dims, - true, - &dx_tmp, - dx_bd_dims, - 1); - ExecuteMatMulV2(ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - dout, - dout_dims, - true, - x, - x_dims, - true, - &dy_tmp, - dy_bd_dims, - 2); + ExecuteMatMulV2(ctx, + dev_ctx, + onednn_engine, + ctx.GetPlace(), + y, + y_dims, + true, + dout, + dout_dims, + true, + &dx_tmp, + dx_bd_dims, + 1); + ExecuteMatMulV2(ctx, + dev_ctx, + onednn_engine, + ctx.GetPlace(), + dout, + dout_dims, + true, + x, + x_dims, + true, + &dy_tmp, + dy_bd_dims, + 2); } else if (trans_x) { - ExecuteMatMulV2(ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - y, - y_dims, - false, - dout, - dout_dims, - true, - &dx_tmp, - dx_bd_dims, - 1); - ExecuteMatMulV2(ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - x, - x_dims, - false, - dout, - dout_dims, - false, - &dy_tmp, - dy_bd_dims, - 2); + ExecuteMatMulV2(ctx, + dev_ctx, + onednn_engine, + ctx.GetPlace(), + y, + y_dims, + false, + dout, + dout_dims, + true, + &dx_tmp, + dx_bd_dims, + 1); + ExecuteMatMulV2(ctx, + dev_ctx, + onednn_engine, + ctx.GetPlace(), + x, + x_dims, + false, + dout, + dout_dims, + false, + &dy_tmp, + dy_bd_dims, + 2); } else if (trans_y) { - ExecuteMatMulV2(ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - dout, - dout_dims, - false, - y, - y_dims, - false, - &dx_tmp, - dx_bd_dims, - 1); - ExecuteMatMulV2(ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - dout, - dout_dims, - true, - x, - x_dims, - false, - &dy_tmp, - dy_bd_dims, - 2); + ExecuteMatMulV2(ctx, + dev_ctx, + onednn_engine, + ctx.GetPlace(), + dout, + dout_dims, + false, + y, + y_dims, + false, + &dx_tmp, + dx_bd_dims, + 1); + ExecuteMatMulV2(ctx, + dev_ctx, + onednn_engine, + ctx.GetPlace(), + dout, + dout_dims, + true, + x, + x_dims, + false, + &dy_tmp, + dy_bd_dims, + 2); } else { - ExecuteMatMulV2(ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - dout, - dout_dims, - false, - y, - y_dims, - true, - &dx_tmp, - dx_bd_dims, - 1); - ExecuteMatMulV2(ctx, - dev_ctx, - onednn_engine, - ctx.GetPlace(), - x, - x_dims, - true, - dout, - dout_dims, - false, - &dy_tmp, - dy_bd_dims, - 2); + ExecuteMatMulV2(ctx, + dev_ctx, + onednn_engine, + ctx.GetPlace(), + dout, + dout_dims, + false, + y, + y_dims, + true, + &dx_tmp, + dx_bd_dims, + 1); + ExecuteMatMulV2(ctx, + dev_ctx, + onednn_engine, + ctx.GetPlace(), + x, + x_dims, + true, + dout, + dout_dims, + false, + &dy_tmp, + dy_bd_dims, + 2); } if (x_dims != dx_bd_dims) { @@ -1234,34 +1259,13 @@ template class MatMulGradMKLDNNKernel; namespace ops = paddle::operators; -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(matmul, - MKLDNN, - ::paddle::platform::CPUPlace, - S8, - 0, - MatMulMKLDNNKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(matmul, - MKLDNN, - ::paddle::platform::CPUPlace, - U8, - 0, - MatMulMKLDNNKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(matmul, - MKLDNN, - ::paddle::platform::CPUPlace, - FP32, - 0, - MatMulV2MKLDNNKernel); - -REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( - matmul, - MKLDNN, - ::paddle::platform::CPUPlace, - BF16, - 0, - MatMulV2MKLDNNKernel); +REGISTER_OP_KERNEL(matmul, + MKLDNN, + ::paddle::platform::CPUPlace, + MatMulV2MKLDNNKernel, + MatMulV2MKLDNNKernel, + MatMulV2MKLDNNKernel, + MatMulV2MKLDNNKernel); REGISTER_OP_KERNEL(matmul_grad, MKLDNN, @@ -1273,7 +1277,9 @@ REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace, MatMulV2MKLDNNKernel, - MatMulV2MKLDNNKernel); + MatMulV2MKLDNNKernel, + MatMulV2MKLDNNKernel, + MatMulV2MKLDNNKernel); REGISTER_OP_KERNEL(matmul_v2_grad, MKLDNN, diff --git a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc index 956dbc810f..e727a4fe9f 100644 --- a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc @@ -416,16 +416,16 @@ class MulMKLDNNKernel : public framework::OpKernel { bool trans_y, Tensor *out) const { static const std::vector vec_placeholder; - MatMulV2MKLDNNHandler handler(ctx, - onednn_engine, - ctx.GetPlace(), - x_dims, - trans_x, - y_dims, - trans_y, - false, - vec_placeholder, - vec_placeholder); + MatMulV2MKLDNNHandler handler(ctx, + onednn_engine, + ctx.GetPlace(), + x_dims, + trans_x, + y_dims, + trans_y, + false, + vec_placeholder, + vec_placeholder); const auto src_memory_p = handler.AcquireSrcMemory(x); const auto weights_memory_p = handler.AcquireWeightsMemory(y); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 6c802b682e..1c3208e441 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -860,8 +860,18 @@ class ReductionMKLDNNHandler }; template +constexpr bool IsInt8() { + return std::is_same::value || std::is_same::value; +} + +template +constexpr bool IsBfloat16() { + return std::is_same::value; +} + +template class MatMulV2MKLDNNHandler - : public paddle::platform::MKLDNNHandlerNoCachingT { + : public paddle::platform::MKLDNNHandlerNoCachingT { public: MatMulV2MKLDNNHandler(const framework::ExecutionContext& ctx, const dnnl::engine engine, @@ -873,8 +883,8 @@ class MatMulV2MKLDNNHandler bool is_output_fused, const std::vector& x_strides_override, const std::vector& y_strides_override) - : paddle::platform::MKLDNNHandlerNoCachingT(engine, - cpu_place) { + : paddle::platform::MKLDNNHandlerNoCachingT(engine, + cpu_place) { // M X K * K X N std::vector x_dims(x_org_dims); std::vector y_dims(y_org_dims); @@ -934,28 +944,42 @@ class MatMulV2MKLDNNHandler out_strides[i] = out_ddims[i + 1] * out_strides[i + 1]; } - if (is_output_fused) { + if (!IsInt8() && !IsBfloat16() && is_output_fused) { out_strides = FakeTransposeStrides(out_ddims); } - auto x_md = memory::desc(x_dims, MKLDNNGetDataType(), x_strides); - auto y_md = memory::desc(y_dims, MKLDNNGetDataType(), y_strides); - auto out_md = memory::desc(out_ddims, MKLDNNGetDataType(), out_strides); + auto x_md = memory::desc(x_dims, MKLDNNGetDataType(), x_strides); + auto y_md = memory::desc(y_dims, MKLDNNGetDataType(), y_strides); + auto out_md = memory::desc(out_ddims, MKLDNNGetDataType(), out_strides); const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs(ctx); this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md); } - // TODO(jczaja) : Adapt to int8 + float ComputeOutputScale(const framework::ExecutionContext& ctx) { + float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 1.0f; + if (ctx.HasAttr("Scale_x") && ctx.HasAttr("Scale_y") && + ctx.HasAttr("Scale_out")) { + float scale_x = ctx.Attr("Scale_x"); + float scale_y = ctx.Attr("Scale_y"); + bool force_fp32_out = ctx.HasAttr("force_fp32_output") + ? ctx.Attr("force_fp32_output") + : false; + float scale_out = force_fp32_out ? 1.f : ctx.Attr("Scale_out"); + alpha *= scale_out / (scale_x * scale_y); + } + return alpha; + } + dnnl::primitive_attr CreateMatmulAttrs( const framework::ExecutionContext& ctx) { dnnl::primitive_attr matmul_attrs; dnnl::post_ops post_operations; - float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 1.0f; - if (alpha != 1.0f) { - matmul_attrs.set_output_scales(0, {alpha}); + float scale_out = ComputeOutputScale(ctx); + if (scale_out != 1.0f) { + matmul_attrs.set_output_scales(0, {scale_out}); } if (ctx.HasInput("ResidualData")) { @@ -993,9 +1017,23 @@ class MatMulV2MKLDNNHandler } std::shared_ptr AcquireWeightsMemory(const Tensor* input) { - const T* input_data = input->data(); + const YT* input_data = input->data(); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(), - to_void_cast(input_data)); + to_void_cast(input_data)); + } + + std::shared_ptr AcquireDstMemory( + paddle::framework::Tensor* output) { + // We cannot use base AcquireDstMemory as it makes an allocation request + // base on DST memory primitive size. This is fine in general, but in MatMul + // we have primitive that covers only one batch of Data and then shift + // pointer for every new batch. Hence Tensor size is bigger that dst memory + // primitive size. So would we request less memory that is there and it + // triggers an + // assertion. So as there is no 'any' format here we can leave default size + // of Tensor as computed in ComputeInferShape + OT* ptr = output->mutable_data(this->place_); + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr); } }; @@ -1099,11 +1137,11 @@ class ActivationMKLDNNHandler static std::unordered_map GetAttributeMap( std::string act_type) { std::unordered_map attr_map; - if (act_type == "swish") + if (act_type == "swish") { attr_map.emplace("beta", "fuse_alpha"); - else if (act_type == "relu6") + } else if (act_type == "relu6") { attr_map.emplace("threshold", "fuse_alpha"); - else if (act_type == "hard_sigmoid") { + } else if (act_type == "hard_sigmoid") { attr_map.emplace("slope", "fuse_alpha"); attr_map.emplace("offset", "fuse_beta"); } else if (act_type == "clip") { -- GitLab