Unverified commit f41ccbd5, authored by Sławomir Siwek, committed by GitHub

[PHI] Migrate matmul kernel (#48162)

* cleanup unused code

* unify is_int8 is_bfloat16

* Simplify matmul_v2 FWD kernel

* remove RunKernel methods

* remove import namespace

* remove headers

* clean fluid/phi cross imports

* remove fluid axpy_handler

* delete fluid methods

* activations

* OneDNNMemDesc

* MKLDNNFormatForSize

* MatchShapeToLayout

* MKLDNNMemoryFormat

* MKLDNNFormat

* ReorderMKLDNNHandler

* to_void_cast

* review suggestions

* interpolate

* remove fluid dependency

* init

* ExecuteMatMulV2

* rm fluid kernel

* matmul_grad

* remove mutable_data

* mul_grad

* matmul fwd

* add extra attr

* temp disable passes

* re-enable passes

* workaround for matmul+act

* fix for matmul+eltwise_add

* fix typo

* merge bugfix #48364

* remove merge conflict
Parent c928a35e
@@ -381,7 +381,7 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
 }
 
 template <typename T>
-class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
+class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
  public:
   void Compute(const ExecutionContext &ctx) const override {
     if (ctx.HasAttr("head_number")) {
@@ -696,21 +696,13 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
 REGISTER_OP_KERNEL(matmul,
                    MKLDNN,
                    ::paddle::platform::CPUPlace,
-                   MatMulV2MKLDNNKernel<float>,
-                   MatMulV2MKLDNNKernel<paddle::platform::bfloat16>,
-                   MatMulV2MKLDNNKernel<int8_t>,
-                   MatMulV2MKLDNNKernel<uint8_t>);
+                   MatMulMKLDNNKernel<float>,
+                   MatMulMKLDNNKernel<paddle::platform::bfloat16>,
+                   MatMulMKLDNNKernel<int8_t>,
+                   MatMulMKLDNNKernel<uint8_t>);
 
 REGISTER_OP_KERNEL(matmul_grad,
                    MKLDNN,
                    ::paddle::platform::CPUPlace,
                    MatMulGradMKLDNNKernel<float>,
                    MatMulGradMKLDNNKernel<paddle::platform::bfloat16>);
-
-REGISTER_OP_KERNEL(matmul_v2,
-                   MKLDNN,
-                   ::paddle::platform::CPUPlace,
-                   MatMulV2MKLDNNKernel<float>,
-                   MatMulV2MKLDNNKernel<paddle::platform::bfloat16>,
-                   MatMulV2MKLDNNKernel<int8_t>,
-                   MatMulV2MKLDNNKernel<uint8_t>);
@@ -98,6 +98,7 @@ const std::unordered_map<std::string, ExtraAttrPropertySet>
         {"fuse_alpha", ExtraAttrProperty::ONEDNN},
         {"fuse_beta", ExtraAttrProperty::ONEDNN},
         {"fuse_relu", ExtraAttrProperty::ONEDNN},
+        {"fused_output_scale", ExtraAttrProperty::ONEDNN},
         {"fuse_residual_connection", ExtraAttrProperty::ONEDNN},
         {"fuse_with_relu", ExtraAttrProperty::ONEDNN},
         {"fused_reshape_Out", ExtraAttrProperty::ONEDNN},
@@ -221,7 +222,8 @@ class ExtraInfoUtils {
   std::unordered_map<std::string, std::vector<std::string>>
       g_extra_input_names_map_ = {{"conv2d", {"Bias", "ResidualData"}},
                                   {"conv2d_transpose", {"Bias"}},
-                                  {"conv2d_grad", {"Bias"}}};
+                                  {"conv2d_grad", {"Bias"}},
+                                  {"matmul_v2", {"ResidualData"}}};
   std::vector<std::string> empty_extra_input_names_;
 };
...
@@ -1874,9 +1874,11 @@ class MatmulOneDNNHandler : public OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
     if (scale_out != 1.0f) {
       matmul_attrs.set_output_scales(0, {scale_out});
     }
+    const auto* residual_data = dev_ctx.HasDnnInput("ResidualData")
+                                    ? dev_ctx.GetDnnInput("ResidualData")
+                                    : nullptr;
 
-    if (dev_ctx.HasDnnInput("ResidualData")) {
-      auto* residual_data = dev_ctx.GetDnnInput("ResidualData");
+    if (residual_data) {
       auto residual_data_tz = vectorize(residual_data->dims());
       auto residual_data_md = memory::desc(residual_data_tz,
                                            OneDNNGetDataType<OT>(),
@@ -1893,9 +1895,11 @@ class MatmulOneDNNHandler : public OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
 
     AppendActivation(dev_ctx, post_operations);
 
-    if (dev_ctx.HasDnnAttr("fused_output_scale")) {
-      float scale_alpha =
-          PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fused_output_scale"));
+    const float scale_alpha =
+        dev_ctx.HasDnnAttr("fused_output_scale")
+            ? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fused_output_scale"))
+            : 1.0f;
+    if (scale_alpha != 1.0f) {
       post_operations.append_eltwise(
           1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
     }
@@ -2014,8 +2018,11 @@ void ExecuteMatmul(const OneDNNContext& dev_ctx,
       {DNNL_ARG_WEIGHTS, *weights_memory_p},
       {DNNL_ARG_DST, *dst_memory_p}};
 
-  if (dev_ctx.HasDnnInput("ResidualData")) {
-    auto* residual_data = dev_ctx.GetDnnInput("ResidualData");
+  const auto* residual_data = dev_ctx.HasDnnInput("ResidualData")
+                                  ? dev_ctx.GetDnnInput("ResidualData")
+                                  : nullptr;
+  if (residual_data) {
     const auto residual_data_memory_p = handler.AcquireSrcMemory(residual_data);
     matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
                         *residual_data_memory_p});
...
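For readers unfamiliar with the oneDNN post-op mechanism used above, the following is a minimal standalone sketch of the same pattern, written against the oneDNN 2.x C++ API (an assumption; this is not code from the commit, and all shapes, names, and the scale value are illustrative). It builds a dnnl::matmul primitive whose attributes carry a binary_add post-op for a residual input and an eltwise_linear post-op playing the role of fused_output_scale, and it binds the residual memory at execution time through DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, mirroring the ResidualData binding in ExecuteMatmul.

// Standalone sketch of the matmul + residual-add + output-scale fusion
// (oneDNN 2.x API; values are illustrative, not taken from the commit).
#include <vector>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  const memory::dim M = 2, K = 3, N = 4;
  memory::desc src_md({M, K}, memory::data_type::f32, memory::format_tag::ab);
  memory::desc wei_md({K, N}, memory::data_type::f32, memory::format_tag::ab);
  memory::desc dst_md({M, N}, memory::data_type::f32, memory::format_tag::ab);
  memory::desc res_md({M, N}, memory::data_type::f32, memory::format_tag::ab);

  // Post-ops: residual add first (post-op index 0), then the output scale
  // expressed as an eltwise_linear post-op, as in the handler above.
  post_ops ops;
  ops.append_binary(algorithm::binary_add, res_md);
  const float scale_alpha = 0.5f;  // illustrative "fused_output_scale" value
  ops.append_eltwise(1.0f, algorithm::eltwise_linear, scale_alpha, 0.0f);
  primitive_attr attr;
  attr.set_post_ops(ops);

  matmul::desc mm_desc(src_md, wei_md, dst_md);
  matmul::primitive_desc mm_pd(mm_desc, attr, eng);
  matmul mm(mm_pd);

  std::vector<float> src(M * K, 1.0f), wei(K * N, 1.0f), dst(M * N, 0.0f),
      res(M * N, 1.0f);
  memory src_mem(src_md, eng, src.data());
  memory wei_mem(wei_md, eng, wei.data());
  memory dst_mem(dst_md, eng, dst.data());
  memory res_mem(res_md, eng, res.data());

  // The residual tensor is supplied as SRC_1 of post-op 0, matching the
  // DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1 binding above.
  mm.execute(strm,
             {{DNNL_ARG_SRC, src_mem},
              {DNNL_ARG_WEIGHTS, wei_mem},
              {DNNL_ARG_DST, dst_mem},
              {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, res_mem}});
  strm.wait();
  // dst = ((src x wei) + res) * 0.5, i.e. (3 + 1) * 0.5 = 2 per element.
  return 0;
}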
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/matmul_kernel.h"

#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
DDim GetDimsForInput(const OneDNNContext &dev_ctx,
DDim input_dims,
std::string input_name) {
auto shape =
dev_ctx.HasDnnAttr("fused_reshape_" + input_name)
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_reshape_" + input_name))
: std::vector<int>();
auto axis = dev_ctx.HasDnnAttr("fused_transpose_" + input_name)
? PADDLE_GET_CONST(
std::vector<int>,
dev_ctx.GetDnnAttr("fused_transpose_" + input_name))
: std::vector<int>();
if (!shape.empty() && !axis.empty()) {
return input_dims.reshape(shape).transpose(axis);
}
return input_dims;
}

void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
const std::vector<int64_t> &y_dims,
std::vector<int64_t> *x_bd_dims,
std::vector<int64_t> *y_bd_dims,
DenseTensor *out,
const bool is_output_fused) {
if (x_dims.size() == 1) {
(*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0];
} else if (x_dims.size() == 2) {
(*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[1];
(*x_bd_dims)[(*x_bd_dims).size() - 2] = x_dims[0];
} else {
for (size_t i = 0; i < x_dims.size(); ++i) {
(*x_bd_dims)[(*x_bd_dims).size() - x_dims.size() + i] = x_dims[i];
}
}
if (y_dims.size() == 1) {
(*y_bd_dims)[(*x_bd_dims).size() - 2] = y_dims[0];
} else if (y_dims.size() == 2) {
(*y_bd_dims)[(*y_bd_dims).size() - 1] = y_dims[1];
(*y_bd_dims)[(*y_bd_dims).size() - 2] = y_dims[0];
} else {
for (size_t i = 0; i < y_dims.size(); ++i) {
(*y_bd_dims)[(*y_bd_dims).size() - y_dims.size() + i] = y_dims[i];
}
}
if (!is_output_fused && x_dims.size() > 2 && y_dims.size() > 2) {
auto out_dims = vectorize(out->dims());
for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) {
PADDLE_ENFORCE_EQ(
(*x_bd_dims)[i] == (*y_bd_dims)[i] || (*x_bd_dims)[i] == 1 ||
(*y_bd_dims)[i] == 1,
true,
errors::InvalidArgument(
"Tensor dimensions are incorrect for broadcasting."
"Dimensions in X and Y must be same or equal to 1, but "
"received x_dim[%d]=%d and y_dims[%d]= %d",
i,
(*x_bd_dims)[i],
i,
(*y_bd_dims)[i]));
(out_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
}
out->Resize(make_ddim((out_dims)));
}
}

template <typename T, typename Context>
void MatmulKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
bool transpose_x,
bool transpose_y,
DenseTensor *out) {
if (dev_ctx.HasDnnAttr("head_number")) {
const auto head_number =
PADDLE_GET_CONST(int, dev_ctx.GetDnnAttr("head_number"));
PADDLE_ENFORCE_EQ(
head_number,
1,
errors::Unimplemented(
"oneDNN matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d",
head_number));
}
constexpr bool is_int8 = funcs::is_int8<T>();
constexpr bool is_bfloat16 = funcs::is_bfloat16<T>();
const bool force_fp32_output =
dev_ctx.HasDnnAttr("force_fp32_output")
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
: false;
bool fuse_relu = false;
if (dev_ctx.HasDnnAttr("fuse_activation")) {
auto act_type =
PADDLE_GET_CONST(std::string, dev_ctx.GetDnnAttr("fuse_activation"));
if (act_type == "relu" || act_type == "relu6") {
fuse_relu = true;
}
}
auto x_dims = vectorize(GetDimsForInput(dev_ctx, x.dims(), "X"));
auto y_dims = vectorize(GetDimsForInput(dev_ctx, y.dims(), "Y"));
int ndims = std::max(x_dims.size(), y_dims.size());
ndims = std::max(ndims, 3);
std::vector<int64_t> x_bd_dims(ndims, 1);
std::vector<int64_t> y_bd_dims(ndims, 1);
CalculateMatrixDims(x_dims,
y_dims,
&x_bd_dims,
&y_bd_dims,
out,
funcs::IsOutputFused(dev_ctx));
if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
funcs::ExecuteMatmul<T, float>(
dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
} else if (is_bfloat16) {
funcs::ExecuteMatmul<T, paddle::platform::bfloat16>(
dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
} else if (fuse_relu) {
funcs::ExecuteMatmul<T, uint8_t>(
dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
} else {
funcs::ExecuteMatmul<T, int8_t>(
dev_ctx, x, y, x_bd_dims, y_bd_dims, transpose_x, transpose_y, out);
}
}

} // namespace phi

PD_REGISTER_KERNEL(matmul,
OneDNN,
ONEDNN,
phi::MatmulKernel,
float,
phi::dtype::bfloat16,
int8_t,
uint8_t) {}
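As a side note on the shape handling in the new kernel: CalculateMatrixDims above right-aligns both operand shapes into vectors of a common rank (at least 3), fills missing leading dimensions with 1, and, when both inputs are batched, broadcasts the leading dimensions of the output to the element-wise maxima. The snippet below is a simplified standalone illustration of that logic, not Paddle code; the helper PadToRank and the example shapes are made up for this sketch.

// Simplified illustration of the batch-dimension padding and broadcasting
// performed by CalculateMatrixDims (standalone, not Paddle code).
#include <algorithm>
#include <cassert>
#include <iostream>
#include <vector>

std::vector<int64_t> PadToRank(const std::vector<int64_t>& dims, size_t rank) {
  // Right-align the shape into a vector of length `rank`, padding with 1s.
  std::vector<int64_t> padded(rank, 1);
  std::copy(dims.begin(), dims.end(), padded.end() - dims.size());
  return padded;
}

int main() {
  std::vector<int64_t> x_dims = {2, 3, 4};  // batched LHS
  std::vector<int64_t> y_dims = {4, 5};     // unbatched RHS

  size_t ndims = std::max({x_dims.size(), y_dims.size(), size_t{3}});
  auto x_bd = PadToRank(x_dims, ndims);  // {2, 3, 4}
  auto y_bd = PadToRank(y_dims, ndims);  // {1, 4, 5}

  // Broadcast the leading (batch) dimensions, as the kernel does for the
  // output shape when both inputs have rank > 2.
  std::vector<int64_t> out_batch(ndims - 2);
  for (size_t i = 0; i < ndims - 2; ++i) {
    assert(x_bd[i] == y_bd[i] || x_bd[i] == 1 || y_bd[i] == 1);
    out_batch[i] = std::max(x_bd[i], y_bd[i]);
  }

  std::cout << "batch=" << out_batch[0] << " M=" << x_bd[ndims - 2]
            << " N=" << y_bd[ndims - 1] << "\n";  // batch=2 M=3 N=5
  return 0;
}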