未验证 提交 2c444dfa 编写于 作者: S Sławomir Siwek 提交者: GitHub

Replace matmul with matmul_v2 during oneDNN fuse passes (#49108)

* replace matmul with matmul_v2 in fuse passes

* Remove fusion logic from matmul

* removing fusion methods

* add proper name

* adjust namespaces
上级 958b9f07
...@@ -77,6 +77,16 @@ void MatmulActivationMkldnnFusePass::FuseMatmulAct( ...@@ -77,6 +77,16 @@ void MatmulActivationMkldnnFusePass::FuseMatmulAct(
? "gelu_tanh" ? "gelu_tanh"
: "gelu_erf"; : "gelu_erf";
} }
if (matmul_type == "matmul") {
matmul_op->SetType("matmul_v2");
matmul_op->SetAttr("trans_x", matmul_op->GetAttr("transpose_X"));
matmul_op->SetAttr("trans_y", matmul_op->GetAttr("transpose_Y"));
auto matmul_alpha = matmul_op->GetAttrIfExists<float>("alpha");
if (matmul_alpha != 1.0f) {
matmul_op->SetAttr("alpha", matmul_alpha);
}
}
matmul_op->SetAttr("fuse_activation", act_type); matmul_op->SetAttr("fuse_activation", act_type);
matmul_op->SetOutput("Out", {activation_out->Name()}); matmul_op->SetOutput("Out", {activation_out->Name()});
......
...@@ -65,6 +65,16 @@ void MatmulElementwiseAddMKLDNNFusePass::FuseMatmulElementwiseAdd( ...@@ -65,6 +65,16 @@ void MatmulElementwiseAddMKLDNNFusePass::FuseMatmulElementwiseAdd(
return; return;
} }
if (matmul_type == "matmul") {
matmul->Op()->SetType("matmul_v2");
matmul->Op()->SetAttr("trans_x", matmul->Op()->GetAttr("transpose_X"));
matmul->Op()->SetAttr("trans_y", matmul->Op()->GetAttr("transpose_Y"));
auto matmul_alpha = matmul->Op()->GetAttrIfExists<float>("alpha");
if (matmul_alpha != 1.0f) {
matmul->Op()->SetAttr("alpha", matmul_alpha);
}
}
matmul->Op()->SetInput("ResidualData", {elementwise_addend->Name()}); matmul->Op()->SetInput("ResidualData", {elementwise_addend->Name()});
matmul->Op()->SetOutput("Out", {elementwise_add_out->Name()}); matmul->Op()->SetOutput("Out", {elementwise_add_out->Name()});
......
...@@ -84,6 +84,15 @@ void MatmulTransposeReshapeMKLDNNPass::Fuse( ...@@ -84,6 +84,15 @@ void MatmulTransposeReshapeMKLDNNPass::Fuse(
} }
OpDesc *matmul_desc = matmul_op->Op(); OpDesc *matmul_desc = matmul_op->Op();
if (matmul_type == "matmul") {
matmul_desc->SetType("matmul_v2");
matmul_desc->SetAttr("trans_x", matmul_desc->GetAttr("transpose_X"));
matmul_desc->SetAttr("trans_y", matmul_desc->GetAttr("transpose_Y"));
auto matmul_alpha = matmul_desc->GetAttrIfExists<float>("alpha");
if (matmul_alpha != 1.0f) {
matmul_desc->SetAttr("alpha", matmul_alpha);
}
}
matmul_desc->SetOutput("Out", {reshape_out->Name()}); matmul_desc->SetOutput("Out", {reshape_out->Name()});
matmul_desc->SetAttr("fused_reshape_Out", reshape_shape); matmul_desc->SetAttr("fused_reshape_Out", reshape_shape);
matmul_desc->SetAttr("fused_transpose_Out", transpose_axis); matmul_desc->SetAttr("fused_transpose_Out", transpose_axis);
......
...@@ -85,6 +85,17 @@ void FuseOperatorScaleOneDNNPass::FuseScale(Graph *graph, ...@@ -85,6 +85,17 @@ void FuseOperatorScaleOneDNNPass::FuseScale(Graph *graph,
scale = *(scale_tensor->data<float>()); scale = *(scale_tensor->data<float>());
} }
if (op_type == "matmul") {
operator_op->Op()->SetType("matmul_v2");
operator_op->Op()->SetAttr("trans_x",
operator_op->Op()->GetAttr("transpose_X"));
operator_op->Op()->SetAttr("trans_y",
operator_op->Op()->GetAttr("transpose_Y"));
auto matmul_alpha = operator_op->Op()->GetAttrIfExists<float>("alpha");
if (matmul_alpha != 1.0f) {
operator_op->Op()->SetAttr("alpha", matmul_alpha);
}
}
operator_op->Op()->SetAttr("fused_output_scale", scale); operator_op->Op()->SetAttr("fused_output_scale", scale);
operator_op->Op()->SetOutput("Out", {scale_out->Name()}); operator_op->Op()->SetOutput("Out", {scale_out->Name()});
......
...@@ -123,6 +123,15 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse( ...@@ -123,6 +123,15 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
return; return;
} }
if (matmul_type == "matmul") {
matmul_desc->SetType("matmul_v2");
matmul_desc->SetAttr("trans_x", matmul_desc->GetAttr("transpose_X"));
matmul_desc->SetAttr("trans_y", matmul_desc->GetAttr("transpose_Y"));
auto matmul_alpha = matmul_desc->GetAttrIfExists<float>("alpha");
if (matmul_alpha != 1.0f) {
matmul_desc->SetAttr("alpha", matmul_alpha);
}
}
matmul_desc->SetInput(matmul_input_name, {(reshape_in)->Name()}); matmul_desc->SetInput(matmul_input_name, {(reshape_in)->Name()});
matmul_desc->SetAttr("fused_reshape_" + matmul_input_name, reshape_shape); matmul_desc->SetAttr("fused_reshape_" + matmul_input_name, reshape_shape);
matmul_desc->SetAttr("fused_transpose_" + matmul_input_name, matmul_desc->SetAttr("fused_transpose_" + matmul_input_name,
......
...@@ -97,7 +97,7 @@ void TestMain(const std::string& op_name, bool with_xshapes) { ...@@ -97,7 +97,7 @@ void TestMain(const std::string& op_name, bool with_xshapes) {
int removed = 8; // 2* reshape, reshape_out, transpose, transpose_out int removed = 8; // 2* reshape, reshape_out, transpose, transpose_out
if (with_xshapes) removed += 2; // transpose_xshape, reshape_xshape if (with_xshapes) removed += 2; // transpose_xshape, reshape_xshape
EXPECT_EQ(total_nodes_before - removed, total_nodes_after); EXPECT_EQ(total_nodes_before - removed, total_nodes_after);
auto* matmul_op_desc = GetOpNodes(graph, op_name).at(0)->Op(); auto* matmul_op_desc = GetOpNodes(graph, "matmul_v2").at(0)->Op();
auto check = [&matmul_op_desc](std::string a) { auto check = [&matmul_op_desc](std::string a) {
std::string shape_str = "fused_reshape_" + a; std::string shape_str = "fused_reshape_" + a;
......
...@@ -345,26 +345,6 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -345,26 +345,6 @@ class MatMulGradKernel : public framework::OpKernel<T> {
} }
}; };
framework::DDim GetDimForInput(const framework::InferShapeContext &ctx,
std::string input_name) {
auto shape = ctx.Attrs().Get<std::vector<int>>("fused_reshape_" + input_name);
auto axis =
ctx.Attrs().Get<std::vector<int>>("fused_transpose_" + input_name);
auto dim = ctx.GetInputDim(input_name);
PADDLE_ENFORCE_GT(dim.size(),
0,
platform::errors::InvalidArgument(
"The Input(%s) has not been initialized properly. The "
"shape of Input(%s) = [%s].",
dim));
if (!shape.empty() && !axis.empty()) {
dim = dim.reshape(shape).transpose(axis);
}
return dim;
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MatMulDoubleGradKernel : public framework::OpKernel<T> { class MatMulDoubleGradKernel : public framework::OpKernel<T> {
public: public:
...@@ -579,8 +559,8 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -579,8 +559,8 @@ class MatMulOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul"); OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul");
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "matmul"); OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "matmul");
auto dim_x = GetDimForInput(*context, "X"); auto dim_x = context->GetInputDim("X");
auto dim_y = GetDimForInput(*context, "Y"); auto dim_y = context->GetInputDim("Y");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// (jczaja): For NHWC execution output shape needs // (jczaja): For NHWC execution output shape needs
...@@ -681,14 +661,6 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -681,14 +661,6 @@ class MatMulOp : public framework::OperatorWithKernel {
framework::DDim ddim_out = phi::make_ddim(dim_out); framework::DDim ddim_out = phi::make_ddim(dim_out);
#ifdef PADDLE_WITH_MKLDNN
auto shape = context->Attrs().Get<std::vector<int>>("fused_reshape_Out");
auto axis = context->Attrs().Get<std::vector<int>>("fused_transpose_Out");
if (!shape.empty() && !axis.empty()) {
ddim_out = ddim_out.transpose(axis).reshape(shape);
}
#endif
context->SetOutputDim("Out", ddim_out); context->SetOutputDim("Out", ddim_out);
context->ShareLoD("X", "Out"); context->ShareLoD("X", "Out");
} }
......
...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h" #include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
...@@ -21,13 +20,14 @@ namespace { ...@@ -21,13 +20,14 @@ namespace {
using dnnl::memory; using dnnl::memory;
using paddle::framework::ExecutionContext; using paddle::framework::ExecutionContext;
using paddle::framework::GradVarName; using paddle::framework::GradVarName;
using phi::DenseTensor;
using phi::OneDNNContext; using phi::OneDNNContext;
using phi::vectorize; using phi::vectorize;
using phi::funcs::OneDNNGetDataType; using phi::funcs::OneDNNGetDataType;
// Reshape a rank-3 tensor from P x M x N to (P * M) x N. // Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3. // Identity op if the tensor is not of rank 3.
static phi::DenseTensor FoldOuterDims(const phi::DenseTensor &input) { static DenseTensor FoldOuterDims(const DenseTensor &input) {
auto output = input; auto output = input;
auto in_dims = input.dims(); auto in_dims = input.dims();
if (in_dims.size() == 3) { if (in_dims.size() == 3) {
...@@ -40,14 +40,14 @@ static phi::DenseTensor FoldOuterDims(const phi::DenseTensor &input) { ...@@ -40,14 +40,14 @@ static phi::DenseTensor FoldOuterDims(const phi::DenseTensor &input) {
// (Warning: This requires transposing data and writes into new memory.) // (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3. // Identity op if the tensor is not of rank 3.
template <typename T> template <typename T>
static phi::DenseTensor FoldFirstAndLastDims(const OneDNNContext &dev_ctx, static DenseTensor FoldFirstAndLastDims(const OneDNNContext &dev_ctx,
const phi::DenseTensor *input) { const DenseTensor *input) {
auto input_dims = vectorize(input->dims()); auto input_dims = vectorize(input->dims());
if (input_dims.size() != 3) { if (input_dims.size() != 3) {
return *input; return *input;
} }
phi::DenseTensor output; DenseTensor output;
output.Resize({input_dims[1], input_dims[0], input_dims[2]}); output.Resize({input_dims[1], input_dims[0], input_dims[2]});
auto output_dims = vectorize(output.dims()); auto output_dims = vectorize(output.dims());
...@@ -71,30 +71,15 @@ static phi::DenseTensor FoldFirstAndLastDims(const OneDNNContext &dev_ctx, ...@@ -71,30 +71,15 @@ static phi::DenseTensor FoldFirstAndLastDims(const OneDNNContext &dev_ctx,
return output; return output;
} }
phi::DDim GetDimForInput(const ExecutionContext &ctx, std::string input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
auto input_dims = ctx.Input<phi::DenseTensor>(input_name)->dims();
if (!shape.empty() && !axis.empty()) {
return input_dims.reshape(shape).transpose(axis);
}
return input_dims;
}
template <typename XT, typename YT, typename OT> template <typename XT, typename YT, typename OT>
class MatMulV2MKLDNNHandler class MatMulV1OneDNNHandler
: public phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> { : public phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
public: public:
MatMulV2MKLDNNHandler(const ExecutionContext &ctx, MatMulV1OneDNNHandler(const ExecutionContext &ctx,
const dnnl::engine engine, const dnnl::engine engine,
paddle::platform::Place cpu_place, phi::Place cpu_place,
const std::vector<int64_t> &x_org_dims, const std::vector<int64_t> &x_org_dims,
bool trans_x, const std::vector<int64_t> &y_org_dims)
const std::vector<int64_t> &y_org_dims,
bool trans_y,
bool is_output_fused,
const std::vector<int64_t> &x_strides_override,
const std::vector<int64_t> &y_strides_override)
: phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul>(engine, : phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul>(engine,
cpu_place) { cpu_place) {
// M X K * K X N // M X K * K X N
...@@ -105,6 +90,8 @@ class MatMulV2MKLDNNHandler ...@@ -105,6 +90,8 @@ class MatMulV2MKLDNNHandler
const int H_idx = x_dims.size() - 2; const int H_idx = x_dims.size() - 2;
const int W_idx = x_dims.size() - 1; const int W_idx = x_dims.size() - 1;
auto trans_x = ctx.Attr<bool>("transpose_X");
auto trans_y = ctx.Attr<bool>("transpose_Y");
if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]); if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]);
if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]); if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]);
...@@ -121,24 +108,16 @@ class MatMulV2MKLDNNHandler ...@@ -121,24 +108,16 @@ class MatMulV2MKLDNNHandler
y_strides.reserve(x_dims.size()); y_strides.reserve(x_dims.size());
out_strides.reserve(x_dims.size()); out_strides.reserve(x_dims.size());
if (!x_strides_override.empty()) { if (trans_x) {
x_strides = x_strides_override; x_strides.insert(x_strides.end(), {M * K, 1, M});
} else { } else {
if (!trans_x) {
x_strides.insert(x_strides.end(), {M * K, K, 1}); x_strides.insert(x_strides.end(), {M * K, K, 1});
} else {
x_strides.insert(x_strides.end(), {M * K, 1, M});
}
} }
if (!y_strides_override.empty()) { if (trans_y) {
y_strides = y_strides_override; y_strides.insert(y_strides.end(), {N * K, 1, K});
} else { } else {
if (!trans_y) {
y_strides.insert(y_strides.end(), {N * K, N, 1}); y_strides.insert(y_strides.end(), {N * K, N, 1});
} else {
y_strides.insert(y_strides.end(), {N * K, 1, K});
}
} }
out_strides.insert(out_strides.end(), {M * N, N, 1}); out_strides.insert(out_strides.end(), {M * N, N, 1});
...@@ -147,20 +126,11 @@ class MatMulV2MKLDNNHandler ...@@ -147,20 +126,11 @@ class MatMulV2MKLDNNHandler
for (int i = x_dims.size() - 4; i >= 0; --i) { for (int i = x_dims.size() - 4; i >= 0; --i) {
out_ddims[i] = std::max(x_dims[i], y_dims[i]); out_ddims[i] = std::max(x_dims[i], y_dims[i]);
if (x_strides_override.empty()) {
x_strides[i] = x_dims[i + 1] * x_strides[i + 1]; x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
}
if (y_strides_override.empty()) {
y_strides[i] = y_dims[i + 1] * y_strides[i + 1]; y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
}
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1]; out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
} }
// TODO(jczaja): Why not for int8??
if (!phi::funcs::is_int8<OT>() && is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims);
}
auto x_md = auto x_md =
memory::desc(x_dims, phi::funcs::OneDNNGetDataType<XT>(), x_strides); memory::desc(x_dims, phi::funcs::OneDNNGetDataType<XT>(), x_strides);
auto y_md = auto y_md =
...@@ -168,163 +138,24 @@ class MatMulV2MKLDNNHandler ...@@ -168,163 +138,24 @@ class MatMulV2MKLDNNHandler
auto out_md = memory::desc( auto out_md = memory::desc(
out_ddims, phi::funcs::OneDNNGetDataType<OT>(), out_strides); out_ddims, phi::funcs::OneDNNGetDataType<OT>(), out_strides);
const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs(ctx);
this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
}
void AppendActivation(const ExecutionContext &ctx,
dnnl::post_ops &post_ops, // NOLINT
float activation_scale = 1.0f) {
const auto invalid_attribute =
ctx.HasAttr("fuse_activation")
? ctx.Attr<std::string>("fuse_activation").empty()
: true;
if (invalid_attribute) return;
const auto fuse_activation = ctx.Attr<std::string>("fuse_activation");
const auto fuse_alpha =
ctx.HasAttr("fuse_alpha") ? ctx.Attr<float>("fuse_alpha") : 0.0f;
const auto fuse_beta =
ctx.HasAttr("fuse_beta") ? ctx.Attr<float>("fuse_beta") : 0.0f;
if (fuse_activation == "hard_sigmoid") {
post_ops.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_linear,
fuse_alpha,
fuse_beta);
post_ops.append_eltwise(
activation_scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
} else {
const std::unordered_map<std::string, dnnl::algorithm> activation_map = {
{"abs", dnnl::algorithm::eltwise_abs},
{"clip", dnnl::algorithm::eltwise_clip},
{"gelu", dnnl::algorithm::eltwise_gelu_erf},
{"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
{"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
{"hard_swish", dnnl::algorithm::eltwise_hardswish},
{"leaky_relu", dnnl::algorithm::eltwise_relu},
{"mish", dnnl::algorithm::eltwise_mish},
{"relu", dnnl::algorithm::eltwise_relu},
{"relu6", dnnl::algorithm::eltwise_bounded_relu},
{"sigmoid", dnnl::algorithm::eltwise_logistic},
{"sqrt", dnnl::algorithm::eltwise_sqrt},
{"swish", dnnl::algorithm::eltwise_swish},
{"tanh", dnnl::algorithm::eltwise_tanh}};
const auto &activation_type = activation_map.find(fuse_activation);
PADDLE_ENFORCE_NE(
activation_type,
activation_map.end(),
phi::errors::InvalidArgument(
"Activation '%s' not found in oneDNN algorithms mapper",
fuse_activation));
post_ops.append_eltwise(
activation_scale, activation_type->second, fuse_alpha, fuse_beta);
}
}
float ComputeOutputScale(const ExecutionContext &ctx) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
if (ctx.HasAttr("Scale_x") && ctx.HasAttr("Scale_y") &&
ctx.HasAttr("Scale_out")) {
float scale_x = ctx.Attr<float>("Scale_x");
float scale_y = ctx.Attr<float>("Scale_y");
bool force_fp32_out = ctx.HasAttr("force_fp32_output")
? ctx.Attr<bool>("force_fp32_output")
: false;
float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
alpha *= scale_out / (scale_x * scale_y);
}
return alpha;
}
dnnl::primitive_attr CreateMatmulAttrs(const ExecutionContext &ctx) {
dnnl::primitive_attr matmul_attrs; dnnl::primitive_attr matmul_attrs;
dnnl::post_ops post_operations; dnnl::post_ops post_operations;
float scale_out = ComputeOutputScale(ctx); float scale_out = ComputeOutputScale(ctx);
if (scale_out != 1.0f) { if (scale_out != 1.0f) {
matmul_attrs.set_output_scales(0, {scale_out}); matmul_attrs.set_output_scales(0, {scale_out});
} }
if (ctx.HasInput("ResidualData")) {
auto *residual_data = ctx.Input<phi::DenseTensor>("ResidualData");
auto residual_data_tz = phi::vectorize(residual_data->dims());
auto residual_data_md = memory::desc(residual_data_tz,
phi::funcs::OneDNNGetDataType<OT>(),
dnnl::memory::format_tag::any);
post_operations.append_binary(dnnl::algorithm::binary_add,
residual_data_md);
if (ctx.HasAttr("Scale_in_eltwise")) {
float sum_scale = scale_out / ctx.Attr<float>("Scale_in_eltwise");
post_operations.append_sum(sum_scale);
}
}
AppendActivation(ctx, post_operations);
if (ctx.HasAttr("fused_output_scale")) {
float scale_alpha = ctx.Attr<float>("fused_output_scale");
post_operations.append_eltwise(
1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
}
matmul_attrs.set_post_ops(post_operations); matmul_attrs.set_post_ops(post_operations);
return matmul_attrs;
}
std::vector<int64_t> FakeTransposeStrides(
const std::vector<int64_t> &matmul_out_dims) const {
// fuse matmul_v2 + transpose + reshape guarantees that output is 4D and
// transpose axis are: {0, 2, 1, 3}
std::vector<int64_t> transpose_axis = {0, 2, 1, 3};
std::vector<int64_t> fake_strides(transpose_axis.size());
int ndims = static_cast<int>(transpose_axis.size());
int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
fake_strides[transpose_axis[i]] = total_stride;
total_stride *= matmul_out_dims[transpose_axis[i]];
}
return fake_strides;
}
std::shared_ptr<memory> AcquireWeightsMemory(const phi::DenseTensor *input) { this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
const YT *input_data = input->data<YT>();
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->weights_desc(),
phi::funcs::to_void_cast<YT>(input_data));
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(phi::DenseTensor *output) {
// We cannot use base AcquireDstMemory as it makes an allocation request
// base on DST memory primitive size. This is fine in general, but in MatMul
// we have primitive that covers only one batch of Data and then shift
// pointer for every new batch. Hence phi::DenseTensor size is bigger that
// dst memory primitive size. So would we request less memory that is there
// and it triggers an assertion. So as there is no 'any' format here we can
// leave default size of phi::DenseTensor as computed in ComputeInferShape
OT *ptr = output->mutable_data<OT>(this->place_);
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
} }
};
template <typename XT, typename YT, typename OT> MatMulV1OneDNNHandler(const dnnl::engine engine,
class MatMulMKLDNNHandler phi::Place cpu_place,
: public phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> { DenseTensor *x,
public:
MatMulMKLDNNHandler(const dnnl::engine engine,
paddle::platform::Place cpu_place,
phi::DenseTensor *x,
bool trans_x, bool trans_x,
phi::DenseTensor *y, DenseTensor *y,
bool trans_y, bool trans_y,
phi::DenseTensor *out, DenseTensor *out,
float scale) float scale)
: phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul>(engine, : phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul>(engine,
cpu_place) { cpu_place) {
...@@ -344,10 +175,10 @@ class MatMulMKLDNNHandler ...@@ -344,10 +175,10 @@ class MatMulMKLDNNHandler
memory::dims out_dims = {out_bs, M, N}; memory::dims out_dims = {out_bs, M, N};
memory::dims x_strides = memory::dims x_strides =
!trans_x ? memory::dims{M * K, K, 1} : memory::dims{M * K, 1, M}; trans_x ? memory::dims{M * K, 1, M} : memory::dims{M * K, K, 1};
memory::dims y_strides = memory::dims y_strides =
!trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K}; trans_y ? memory::dims{N * K, 1, K} : memory::dims{N * K, N, 1};
memory::dims out_strides = memory::dims{M * N, N, 1}; memory::dims out_strides = memory::dims{M * N, N, 1};
auto x_md = memory::desc(x_dims, OneDNNGetDataType<XT>(), x_strides); auto x_md = memory::desc(x_dims, OneDNNGetDataType<XT>(), x_strides);
...@@ -360,65 +191,41 @@ class MatMulMKLDNNHandler ...@@ -360,65 +191,41 @@ class MatMulMKLDNNHandler
this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md); this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
} }
std::shared_ptr<memory> AcquireWeightsMemory(const phi::DenseTensor *input) { float ComputeOutputScale(const ExecutionContext &ctx) {
float alpha = ctx.Attr<float>("alpha");
if (ctx.HasAttr("Scale_x") && ctx.HasAttr("Scale_y") &&
ctx.HasAttr("Scale_out")) {
float scale_x = ctx.Attr<float>("Scale_x");
float scale_y = ctx.Attr<float>("Scale_y");
bool force_fp32_out = ctx.HasAttr("force_fp32_output")
? ctx.Attr<bool>("force_fp32_output")
: false;
float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
alpha *= scale_out / (scale_x * scale_y);
}
return alpha;
}
std::shared_ptr<memory> AcquireWeightsMemory(const DenseTensor *input) {
const YT *input_data = input->data<YT>(); const YT *input_data = input->data<YT>();
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->weights_desc(), this->fwd_pd_->weights_desc(),
phi::funcs::to_void_cast<YT>(input_data)); phi::funcs::to_void_cast<YT>(input_data));
} }
public: std::shared_ptr<memory> AcquireDstMemory(DenseTensor *output) {
void Execute(const phi::DenseTensor *x,
const phi::DenseTensor *y,
phi::DenseTensor *out) {
const auto src_memory_p = this->AcquireSrcMemory(x);
const auto weights_memory_p = this->AcquireWeightsMemory(y);
const auto dst_memory_p = this->AcquireDstMemory(out);
auto matmul_p = this->AcquireForwardPrimitive();
std::unordered_map<int, dnnl::memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto &astream = OneDNNContext::tls().get_stream();
// Simulate batch matmul by processing in loop
void *x_ptr = src_memory_p->get_data_handle();
void *y_ptr = weights_memory_p->get_data_handle();
void *out_ptr = dst_memory_p->get_data_handle();
auto offsets = std::make_tuple(x_offset_, y_offset_, out_offset_);
for (uint16_t i = 0; i < batch_size_; ++i) {
src_memory_p->set_data_handle(x_ptr);
weights_memory_p->set_data_handle(y_ptr);
dst_memory_p->set_data_handle(out_ptr);
matmul_p->execute(astream, matmul_args);
x_ptr = static_cast<char *>(x_ptr) + std::get<0>(offsets);
y_ptr = static_cast<char *>(y_ptr) + std::get<1>(offsets);
out_ptr = static_cast<char *>(out_ptr) + std::get<2>(offsets);
}
astream.wait();
out->set_mem_desc(dst_memory_p->get_desc().reshape(out->dims()));
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(phi::DenseTensor *output) {
// We cannot use base AcquireDstMemory as it makes an allocation request // We cannot use base AcquireDstMemory as it makes an allocation request
// base on DST memory primitive size. This is fine in general, but in MatMul // base on DST memory primitive size. This is fine in general, but in MatMul
// we have primitive that covers only one batch of Data and then shift // we have primitive that covers only one batch of Data and then shift
// pointer for every new batch. Hence phi::DenseTensor size is bigger that // pointer for every new batch. Hence DenseTensor size is bigger that
// dst memory primitive size. So would we request less memory that is there // dst memory primitive size. So would we request less memory that is there
// and it triggers an assertion. So as there is no 'any' format here we can // and it triggers an assertion. So as there is no 'any' format here we can
// leave default size of phi::DenseTensor as computed in ComputeInferShape // leave default size of DenseTensor as computed in ComputeInferShape
OT *ptr = output->mutable_data<OT>(this->place_); OT *ptr = output->mutable_data<OT>(this->place_);
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
} }
private: private:
uint32_t x_offset_;
uint32_t y_offset_;
uint32_t out_offset_;
uint16_t batch_size_; uint16_t batch_size_;
}; };
...@@ -429,7 +236,7 @@ class MatMulMKLDNNHandler ...@@ -429,7 +236,7 @@ class MatMulMKLDNNHandler
* If transposed, `H,W` will be swapped. * If transposed, `H,W` will be swapped.
*/ */
static void ReshapeTensorToMatrixSequence( static void ReshapeTensorToMatrixSequence(
phi::DenseTensor *x, const phi::funcs::MatDescriptor &descriptor) { DenseTensor *x, const phi::funcs::MatDescriptor &descriptor) {
int64_t h, w; int64_t h, w;
h = descriptor.height_; h = descriptor.height_;
w = descriptor.width_; w = descriptor.width_;
...@@ -457,9 +264,9 @@ static void ReshapeTensorToMatrixSequence( ...@@ -457,9 +264,9 @@ static void ReshapeTensorToMatrixSequence(
* If any of `X` and `Y` has batch size BatchSize, the out will have the * If any of `X` and `Y` has batch size BatchSize, the out will have the
* BatchSize. * BatchSize.
*/ */
static void ReshapeXYOutToMatrixSequence(phi::DenseTensor *x, static void ReshapeXYOutToMatrixSequence(DenseTensor *x,
phi::DenseTensor *y, DenseTensor *y,
phi::DenseTensor *out, DenseTensor *out,
bool trans_x, bool trans_x,
bool trans_y) { bool trans_y) {
auto x_dim = phi::funcs::RowMatrixDimsFromVector(x->dims()); auto x_dim = phi::funcs::RowMatrixDimsFromVector(x->dims());
...@@ -486,13 +293,13 @@ std::vector<int64_t> Transpose(const std::vector<int64_t> &x, ...@@ -486,13 +293,13 @@ std::vector<int64_t> Transpose(const std::vector<int64_t> &x,
auto axis_set = std::set<int>(axis.begin(), axis.end()); auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(), PADDLE_ENFORCE_EQ(axis_set.size(),
axis_size, axis_size,
paddle::platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"In an axis array, elements must be unique.")); "In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(in_rank, PADDLE_ENFORCE_EQ(
in_rank,
axis_size, axis_size,
paddle::platform::errors::InvalidArgument( phi::errors::InvalidArgument("The input dimension's size "
"The input dimension's size "
"should be equal to the axis's size. " "should be equal to the axis's size. "
"But received dimension is %d, " "But received dimension is %d, "
"axis's size is %d", "axis's size is %d",
...@@ -501,7 +308,7 @@ std::vector<int64_t> Transpose(const std::vector<int64_t> &x, ...@@ -501,7 +308,7 @@ std::vector<int64_t> Transpose(const std::vector<int64_t> &x,
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
axis_size, axis_size,
paddle::platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1).")); "Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size()); std::vector<int64_t> new_x(x.size());
...@@ -511,73 +318,16 @@ std::vector<int64_t> Transpose(const std::vector<int64_t> &x, ...@@ -511,73 +318,16 @@ std::vector<int64_t> Transpose(const std::vector<int64_t> &x,
return new_x; return new_x;
} }
std::vector<int64_t> GetInputStrides(const ExecutionContext &ctx,
const std::string input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
auto input_dims = ctx.Input<phi::DenseTensor>(input_name)->dims();
auto new_dims = input_dims;
if (!shape.empty() && !axis.empty()) {
new_dims = input_dims.reshape(shape).transpose(axis);
}
auto &MatrixDimsFromVector = input_name == "X"
? phi::funcs::RowMatrixDimsFromVector
: phi::funcs::ColumnMatrixDimsFromVector;
phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor(
MatrixDimsFromVector(new_dims),
0,
ctx.HasAttr("trans_x")
? ctx.Attr<bool>(std::string("trans_") +
static_cast<char>(std::tolower(input_name[0])))
: ctx.Attr<bool>(std::string("transpose_") + input_name[0]));
std::vector<int64_t> strides;
if (!shape.empty()) {
auto shape2 = input_dims.reshape(shape);
strides.push_back(1);
for (auto i = shape2.size() - 1; i > 0; --i) {
strides.insert(strides.begin(),
strides.front() * static_cast<int64_t>(shape2[i]));
}
strides = Transpose(strides, axis);
if (shape.size() == 2)
strides.insert(strides.begin(),
static_cast<int64_t>(shape[0] * shape[1]));
mat_dim.stride_ = strides[0];
if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin()));
}
return strides;
}
bool IsOutputFused(const ExecutionContext &ctx) {
auto &fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
auto &fused_transpose_Out = ctx.Attr<std::vector<int>>("fused_transpose_Out");
return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
}
template <typename T, typename T_out> template <typename T, typename T_out>
void ExecuteMatMulV2(const ExecutionContext &ctx, void ExecuteMatMul(const ExecutionContext &ctx,
const dnnl::engine onednn_engine, const DenseTensor *x,
const phi::DenseTensor *x,
const std::vector<int64_t> &x_dims, const std::vector<int64_t> &x_dims,
bool trans_x, const DenseTensor *y,
const phi::DenseTensor *y,
const std::vector<int64_t> &y_dims, const std::vector<int64_t> &y_dims,
bool trans_y, DenseTensor *out) {
phi::DenseTensor *out) { const auto &dev_ctx = ctx.template device_context<OneDNNContext>();
std::vector<int64_t> x_strides_override = GetInputStrides(ctx, "X"); MatMulV1OneDNNHandler<T, T, T_out> handler(
std::vector<int64_t> y_strides_override = GetInputStrides(ctx, "Y"); ctx, dev_ctx.GetEngine(), ctx.GetPlace(), x_dims, y_dims);
MatMulV2MKLDNNHandler<T, T, T_out> handler(ctx,
onednn_engine,
ctx.GetPlace(),
x_dims,
trans_x,
y_dims,
trans_y,
IsOutputFused(ctx),
x_strides_override,
y_strides_override);
const auto src_memory_p = handler.AcquireSrcMemory(x); const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y); const auto weights_memory_p = handler.AcquireWeightsMemory(y);
...@@ -590,38 +340,23 @@ void ExecuteMatMulV2(const ExecutionContext &ctx, ...@@ -590,38 +340,23 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
{DNNL_ARG_WEIGHTS, *weights_memory_p}, {DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}}; {DNNL_ARG_DST, *dst_memory_p}};
if (ctx.HasInput("ResidualData")) {
auto *residual_data = ctx.Input<phi::DenseTensor>("ResidualData");
const auto residual_data_memory_p = handler.AcquireSrcMemory(residual_data);
matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
*residual_data_memory_p});
}
auto &astream = OneDNNContext::tls().get_stream(); auto &astream = OneDNNContext::tls().get_stream();
matmul_p->execute(astream, matmul_args); matmul_p->execute(astream, matmul_args);
astream.wait(); astream.wait();
// TODO(jczaja): Explain why int8 format of dst is ABCD and do not need
// permute
if (IsOutputFused(ctx) && !phi::funcs::is_int8<T_out>()) {
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_Out");
auto permuted_md = dst_memory_p->get_desc().permute_axes(axis);
out->set_mem_desc(permuted_md.reshape(vectorize<int64_t>(out->dims())));
} else {
out->set_mem_desc( out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims()))); dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}
} }
template <typename T> template <typename T>
class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> { class MatMulV1OneDNNKernel : public paddle::framework::OpKernel<T> {
public: public:
void Compute(const ExecutionContext &ctx) const override { void Compute(const ExecutionContext &ctx) const override {
if (ctx.HasAttr("head_number")) { if (ctx.HasAttr("head_number")) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"), ctx.Attr<int>("head_number"),
1, 1,
paddle::platform::errors::Unimplemented( phi::errors::Unimplemented(
"oneDNN matmul doesn't support multiple heads. Expected " "oneDNN matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d", "head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number"))); ctx.Attr<int>("head_number")));
...@@ -633,19 +368,12 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -633,19 +368,12 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
: false; : false;
constexpr bool fuse_relu = false; // TODO(intel): Enable eltwise fuses constexpr bool fuse_relu = false; // TODO(intel): Enable eltwise fuses
const auto &dev_ctx = ctx.template device_context<OneDNNContext>(); auto *x = ctx.Input<DenseTensor>("X");
const auto &onednn_engine = dev_ctx.GetEngine(); auto *y = ctx.Input<DenseTensor>("Y");
auto *out = ctx.Output<DenseTensor>("Out");
auto *x = ctx.Input<phi::DenseTensor>("X");
auto *y = ctx.Input<phi::DenseTensor>("Y");
auto *out = ctx.Output<phi::DenseTensor>("Out");
bool trans_x = ctx.HasAttr("trans_x") ? ctx.Attr<bool>("trans_x")
: ctx.Attr<bool>("transpose_X");
bool trans_y = ctx.HasAttr("trans_y") ? ctx.Attr<bool>("trans_y")
: ctx.Attr<bool>("transpose_Y");
auto x_dims = vectorize(GetDimForInput(ctx, "X")); auto x_dims = vectorize(x->dims());
auto y_dims = vectorize(GetDimForInput(ctx, "Y")); auto y_dims = vectorize(y->dims());
int ndims = std::max(x_dims.size(), y_dims.size()); int ndims = std::max(x_dims.size(), y_dims.size());
ndims = std::max(ndims, 3); ndims = std::max(ndims, 3);
...@@ -653,58 +381,26 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -653,58 +381,26 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
std::vector<int64_t> x_bd_dims(ndims, 1); std::vector<int64_t> x_bd_dims(ndims, 1);
std::vector<int64_t> y_bd_dims(ndims, 1); std::vector<int64_t> y_bd_dims(ndims, 1);
CalculateMatrixDims(ctx, x_dims, y_dims, &x_bd_dims, &y_bd_dims, out); CalculateMatrixDims(x_dims, y_dims, &x_bd_dims, &y_bd_dims, out);
if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) { if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
ExecuteMatMulV2<T, float>(ctx, ExecuteMatMul<T, float>(ctx, x, x_bd_dims, y, y_bd_dims, out);
onednn_engine,
x,
x_bd_dims,
trans_x,
y,
y_bd_dims,
trans_y,
out);
} else if (is_bfloat16) { } else if (is_bfloat16) {
ExecuteMatMulV2<T, paddle::platform::bfloat16>(ctx, ExecuteMatMul<T, phi::dtype::bfloat16>(
onednn_engine, ctx, x, x_bd_dims, y, y_bd_dims, out);
x,
x_bd_dims,
trans_x,
y,
y_bd_dims,
trans_y,
out);
} else if (fuse_relu) { } else if (fuse_relu) {
ExecuteMatMulV2<T, uint8_t>(ctx, ExecuteMatMul<T, uint8_t>(ctx, x, x_bd_dims, y, y_bd_dims, out);
onednn_engine,
x,
x_bd_dims,
trans_x,
y,
y_bd_dims,
trans_y,
out);
} else { } else {
ExecuteMatMulV2<T, int8_t>(ctx, ExecuteMatMul<T, int8_t>(ctx, x, x_bd_dims, y, y_bd_dims, out);
onednn_engine,
x,
x_bd_dims,
trans_x,
y,
y_bd_dims,
trans_y,
out);
} }
} }
private: private:
void CalculateMatrixDims(const ExecutionContext &ctx, void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
const std::vector<int64_t> &x_dims,
const std::vector<int64_t> &y_dims, const std::vector<int64_t> &y_dims,
std::vector<int64_t> *x_bd_dims, std::vector<int64_t> *x_bd_dims,
std::vector<int64_t> *y_bd_dims, std::vector<int64_t> *y_bd_dims,
phi::DenseTensor *out) const { DenseTensor *out) const {
if (x_dims.size() == 1) { if (x_dims.size() == 1) {
(*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0]; (*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0];
} else if (x_dims.size() == 2) { } else if (x_dims.size() == 2) {
...@@ -726,15 +422,15 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -726,15 +422,15 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
} }
} }
if (!IsOutputFused(ctx) && x_dims.size() > 2 && y_dims.size() > 2) { if (x_dims.size() > 2 && y_dims.size() > 2) {
auto out_dims = vectorize(out->dims()); auto out_dims = vectorize(out->dims());
for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) { for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
(*x_bd_dims)[i] == (*y_bd_dims)[i] || (*x_bd_dims)[i] == 1 || (*x_bd_dims)[i] == (*y_bd_dims)[i] || (*x_bd_dims)[i] == 1 ||
(*y_bd_dims)[i] == 1, (*y_bd_dims)[i] == 1,
true, true,
paddle::platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"phi::DenseTensor dimensions are incorrect for broadcasting." "DenseTensor dimensions are incorrect for broadcasting."
"Dimensions in X and Y must be same or equal to 1, but " "Dimensions in X and Y must be same or equal to 1, but "
"received x_dim[%d]=%d and y_dims[%d]= %d", "received x_dim[%d]=%d and y_dims[%d]= %d",
i, i,
...@@ -749,14 +445,14 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -749,14 +445,14 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
}; };
template <typename T> template <typename T>
class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> { class MatMulV1GradOneDNNKernel : public paddle::framework::OpKernel<T> {
public: public:
void Compute(const ExecutionContext &ctx) const override { void Compute(const ExecutionContext &ctx) const override {
if (ctx.HasAttr("head_number")) { if (ctx.HasAttr("head_number")) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"), ctx.Attr<int>("head_number"),
1, 1,
paddle::platform::errors::Unimplemented( phi::errors::Unimplemented(
"oneDNN matmul doesn't support multiple heads. Expected " "oneDNN matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d", "head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number"))); ctx.Attr<int>("head_number")));
...@@ -765,25 +461,18 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -765,25 +461,18 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
const auto &dev_ctx = ctx.template device_context<OneDNNContext>(); const auto &dev_ctx = ctx.template device_context<OneDNNContext>();
const auto &onednn_engine = dev_ctx.GetEngine(); const auto &onednn_engine = dev_ctx.GetEngine();
auto x = *ctx.Input<phi::DenseTensor>("X"); auto x = *ctx.Input<DenseTensor>("X");
auto y = *ctx.Input<phi::DenseTensor>("Y"); auto y = *ctx.Input<DenseTensor>("Y");
auto dout = auto dout = *ctx.Input<DenseTensor>(paddle::framework::GradVarName("Out"));
*ctx.Input<phi::DenseTensor>(paddle::framework::GradVarName("Out")); auto *dx = ctx.Output<DenseTensor>(paddle::framework::GradVarName("X"));
auto *dx = auto *dy = ctx.Output<DenseTensor>(paddle::framework::GradVarName("Y"));
ctx.Output<phi::DenseTensor>(paddle::framework::GradVarName("X"));
auto *dy = bool transpose_x = ctx.Attr<bool>("transpose_X");
ctx.Output<phi::DenseTensor>(paddle::framework::GradVarName("Y")); bool transpose_y = ctx.Attr<bool>("transpose_Y");
bool transpose_x = ctx.HasAttr("transpose_X")
? ctx.Attr<bool>("transpose_X")
: ctx.Attr<bool>("trans_x");
bool transpose_y = ctx.HasAttr("transpose_Y")
? ctx.Attr<bool>("transpose_Y")
: ctx.Attr<bool>("trans_y");
ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
paddle::framework::DDim dx_dims; phi::DDim dx_dims;
if (dx) { if (dx) {
dx_dims = dx->dims(); dx_dims = dx->dims();
if (dx_dims != x.dims()) { if (dx_dims != x.dims()) {
...@@ -791,7 +480,7 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -791,7 +480,7 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
} }
} }
paddle::framework::DDim dy_dims; phi::DDim dy_dims;
if (dy) { if (dy) {
dy_dims = dy->dims(); dy_dims = dy->dims();
if (dy_dims != y.dims()) { if (dy_dims != y.dims()) {
...@@ -871,31 +560,31 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -871,31 +560,31 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
void ExecuteMatMulGrad(const ExecutionContext &ctx, void ExecuteMatMulGrad(const ExecutionContext &ctx,
const OneDNNContext &dev_ctx, const OneDNNContext &dev_ctx,
const dnnl::engine &engine, const dnnl::engine &engine,
phi::DenseTensor *x, DenseTensor *x,
bool trans_x, bool trans_x,
bool is_fold_init_dims_x, bool is_fold_init_dims_x,
phi::DenseTensor *y, DenseTensor *y,
bool trans_y, bool trans_y,
bool is_fold_init_dims_y, bool is_fold_init_dims_y,
phi::DenseTensor *out) const { DenseTensor *out) const {
// gradient is calculated in a different way when broadcasting is used // gradient is calculated in a different way when broadcasting is used
bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) && bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) &&
out->dims().size() == 2; out->dims().size() == 2;
phi::DenseTensor x_combined, y_combined; DenseTensor x_combined, y_combined;
if (!need_combine) { if (need_combine) {
x_combined = *x;
y_combined = *y;
} else {
x_combined = is_fold_init_dims_x ? FoldOuterDims(*x) x_combined = is_fold_init_dims_x ? FoldOuterDims(*x)
: FoldFirstAndLastDims<T>(dev_ctx, x); : FoldFirstAndLastDims<T>(dev_ctx, x);
y_combined = is_fold_init_dims_y ? FoldOuterDims(*y) y_combined = is_fold_init_dims_y ? FoldOuterDims(*y)
: FoldFirstAndLastDims<T>(dev_ctx, y); : FoldFirstAndLastDims<T>(dev_ctx, y);
} else {
x_combined = *x;
y_combined = *y;
} }
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f; float alpha = ctx.Attr<float>("alpha");
MatMulMKLDNNHandler<T, T, T> handler(engine, MatMulV1OneDNNHandler<T, T, T> handler(engine,
ctx.GetPlace(), ctx.GetPlace(),
&x_combined, &x_combined,
trans_x, trans_x,
...@@ -910,7 +599,7 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -910,7 +599,7 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
auto matmul_p = handler.AcquireForwardPrimitive(); auto matmul_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, dnnl::memory> matmul_args = { std::unordered_map<int, memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p}, {DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}}; {DNNL_ARG_DST, *dst_memory_p}};
...@@ -929,13 +618,13 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -929,13 +618,13 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
REGISTER_OP_KERNEL(matmul, REGISTER_OP_KERNEL(matmul,
MKLDNN, MKLDNN,
::phi::CPUPlace, ::phi::CPUPlace,
MatMulMKLDNNKernel<float>, MatMulV1OneDNNKernel<float>,
MatMulMKLDNNKernel<paddle::platform::bfloat16>, MatMulV1OneDNNKernel<phi::dtype::bfloat16>,
MatMulMKLDNNKernel<int8_t>, MatMulV1OneDNNKernel<int8_t>,
MatMulMKLDNNKernel<uint8_t>); MatMulV1OneDNNKernel<uint8_t>);
REGISTER_OP_KERNEL(matmul_grad, REGISTER_OP_KERNEL(matmul_grad,
MKLDNN, MKLDNN,
::phi::CPUPlace, ::phi::CPUPlace,
MatMulGradMKLDNNKernel<float>, MatMulV1GradOneDNNKernel<float>,
MatMulGradMKLDNNKernel<paddle::platform::bfloat16>); MatMulV1GradOneDNNKernel<phi::dtype::bfloat16>);
...@@ -99,7 +99,7 @@ const std::unordered_map<std::string, ExtraAttrPropertySet> ...@@ -99,7 +99,7 @@ const std::unordered_map<std::string, ExtraAttrPropertySet>
{"fuse_alpha", ExtraAttrProperty::ONEDNN}, {"fuse_alpha", ExtraAttrProperty::ONEDNN},
{"fuse_beta", ExtraAttrProperty::ONEDNN}, {"fuse_beta", ExtraAttrProperty::ONEDNN},
{"fuse_relu", ExtraAttrProperty::ONEDNN}, {"fuse_relu", ExtraAttrProperty::ONEDNN},
{"fused_output_scale", ExtraAttrProperty::ONEDNN}, {"alpha", ExtraAttrProperty::ONEDNN},
{"fuse_residual_connection", ExtraAttrProperty::ONEDNN}, {"fuse_residual_connection", ExtraAttrProperty::ONEDNN},
{"fuse_with_relu", ExtraAttrProperty::ONEDNN}, {"fuse_with_relu", ExtraAttrProperty::ONEDNN},
{"fused_reshape_Out", ExtraAttrProperty::ONEDNN}, {"fused_reshape_Out", ExtraAttrProperty::ONEDNN},
......
...@@ -146,7 +146,7 @@ class TestMatmulActivationMkldnnFusePass(PassAutoScanTest): ...@@ -146,7 +146,7 @@ class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
'operator_scale_onednn_fuse_pass', 'operator_scale_onednn_fuse_pass',
], ],
) )
yield config, ['matmul'], (1e-5, 1e-5) yield config, ['matmul_v2'], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
......
...@@ -137,7 +137,7 @@ class TestMatmulElementwiseAddActivationMkldnnFusePass(PassAutoScanTest): ...@@ -137,7 +137,7 @@ class TestMatmulElementwiseAddActivationMkldnnFusePass(PassAutoScanTest):
'matmul_activation_mkldnn_fuse_pass', 'matmul_activation_mkldnn_fuse_pass',
], ],
) )
yield config, ['matmul'], (1e-5, 1e-5) yield config, ['matmul_v2'], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
......
...@@ -76,7 +76,7 @@ class TestMatmulElementwiseAddMkldnnFusePass(PassAutoScanTest): ...@@ -76,7 +76,7 @@ class TestMatmulElementwiseAddMkldnnFusePass(PassAutoScanTest):
config = self.create_inference_config( config = self.create_inference_config(
use_mkldnn=True, passes=['matmul_elementwise_add_mkldnn_fuse_pass'] use_mkldnn=True, passes=['matmul_elementwise_add_mkldnn_fuse_pass']
) )
yield config, ['matmul'], (1e-5, 1e-5) yield config, ['matmul_v2'], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
......
...@@ -116,7 +116,7 @@ class TestMatmulTransposeReshapeMkldnnFusePass(PassAutoScanTest): ...@@ -116,7 +116,7 @@ class TestMatmulTransposeReshapeMkldnnFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config): def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True) config = self.create_inference_config(use_mkldnn=True)
yield config, ["matmul"], (1e-5, 1e-5) yield config, ["matmul_v2"], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
......
...@@ -135,17 +135,8 @@ class TestMatmulv2TransposeReshapeMkldnnFusePass(PassAutoScanTest): ...@@ -135,17 +135,8 @@ class TestMatmulv2TransposeReshapeMkldnnFusePass(PassAutoScanTest):
return program_config return program_config
def sample_predictor_configs(self, program_config): def sample_predictor_configs(self, program_config):
# gpu_cpu_map_matmul_v2_to_matmul_pass will affect the type of final fused op
fused_op = "matmul_v2"
input1_dim1 = program_config.inputs["input_data1"].shape[0]
input2_dim1 = program_config.inputs["input_data2"].shape[0]
input1_dim2 = program_config.inputs["input_data1"].shape[1]
input2_dim2 = program_config.inputs["input_data2"].shape[1]
if input1_dim1 == input2_dim1 and input1_dim2 == input2_dim2:
fused_op = "matmul"
config = self.create_inference_config(use_mkldnn=True) config = self.create_inference_config(use_mkldnn=True)
yield config, [fused_op], (1e-5, 1e-5) yield config, ["matmul_v2"], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
......
...@@ -153,7 +153,7 @@ class TestReshapeTransposeMatmulMkldnnFusePass(PassAutoScanTest): ...@@ -153,7 +153,7 @@ class TestReshapeTransposeMatmulMkldnnFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config): def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True) config = self.create_inference_config(use_mkldnn=True)
yield config, ["matmul"], (1e-5, 1e-5) yield config, ["matmul_v2"], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
......
...@@ -17,7 +17,7 @@ import unittest ...@@ -17,7 +17,7 @@ import unittest
import numpy as np import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci from paddle.fluid.tests.unittests.op_test import OpTest
class TestDnnlMatMulOp(OpTest): class TestDnnlMatMulOp(OpTest):
...@@ -254,321 +254,6 @@ class TestDnnlMatMulOpInt8ForceFP32BasicScales(TestDnnlMatMulOp): ...@@ -254,321 +254,6 @@ class TestDnnlMatMulOpInt8ForceFP32BasicScales(TestDnnlMatMulOp):
self.attrs = {'force_fp32_output': True} self.attrs = {'force_fp32_output': True}
@skip_check_grad_ci(reason="DNNL's MatMul doesn't implement grad kernel.")
class TestReshapeTransposeMatMulOp(OpTest):
def init_data_type(self):
self.data_type_ = 'float32'
def generate_data(self):
self.x = (
np.random.random([2, 128, 768])
.astype("float32")
.reshape([2, 128, 12, 64])
.transpose([0, 2, 1, 3])
)
self.y = (
np.random.random([2, 128, 768])
.astype("float32")
.reshape([2, 128, 12, 64])
.transpose([0, 2, 1, 3])
)
self.out = np.matmul(self.x, self.y.transpose([0, 1, 3, 2]))
self.fused_reshape_X = []
self.fused_transpose_X = []
self.fused_reshape_Y = []
self.fused_transpose_Y = []
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul"
self.transpose_y_name = "transpose_Y"
def setUp(self):
self.set_op_type_and_transpose_y_name()
self._cpu_only = True
self.use_mkldnn = True
self.transpose_y = True
self.init_data_type()
self.generate_data()
self.inputs = {'X': self.x, 'Y': self.y}
self.attrs = {
'use_mkldnn': self.use_mkldnn,
self.transpose_y_name: self.transpose_y,
}
if len(self.fused_transpose_X) > 0:
self.attrs['fused_transpose_X'] = self.fused_transpose_X
if len(self.fused_transpose_Y) > 0:
self.attrs['fused_transpose_Y'] = self.fused_transpose_Y
if len(self.fused_reshape_X) > 0:
self.attrs['fused_reshape_X'] = self.fused_reshape_X
if len(self.fused_reshape_Y) > 0:
self.attrs['fused_reshape_Y'] = self.fused_reshape_Y
self.outputs = {'Out': self.out}
def test_check_output(self):
self.check_output()
class TestReshapeTransposeMatMulOp4DXFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = np.random.random([2, 128, 768]).astype("float32")
self.y = (
np.random.random([2, 128, 768])
.astype("float32")
.reshape([2, 128, 12, 64])
.transpose([0, 2, 1, 3])
)
self.fused_transpose_X = [0, 2, 1, 3]
self.fused_reshape_X = [0, 0, 12, 64]
self.fused_transpose_Y = []
self.fused_reshape_Y = []
self.out = np.matmul(
self.x.reshape([2, 128, 12, 64]).transpose([0, 2, 1, 3]),
self.y.transpose([0, 1, 3, 2]),
)
class TestReshapeTransposeMatMulOp4DXInt8(TestReshapeTransposeMatMulOp4DXFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestReshapeTransposeMatMulOp4DYFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = (
np.random.random([2, 128, 768])
.astype("float32")
.reshape([2, 128, 12, 64])
.transpose([0, 2, 1, 3])
)
self.y = np.random.random([2, 128, 768]).astype("float32")
self.fused_transpose_X = []
self.fused_reshape_X = []
self.fused_transpose_Y = [0, 2, 1, 3]
self.fused_reshape_Y = [0, 0, 12, 64]
self.out = np.matmul(
self.x, self.y.reshape([2, 128, 12, 64]).transpose([0, 2, 3, 1])
)
class TestReshapeTransposeMatMulOp4DYInt8(TestReshapeTransposeMatMulOp4DYFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestReshapeTransposeMatMulOp4DXYFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = np.random.random([2, 128, 768]).astype("float32")
self.y = np.random.random([2, 128, 768]).astype("float32")
self.fused_transpose_X = [0, 2, 1, 3]
self.fused_reshape_X = [0, 0, 12, 64]
self.fused_transpose_Y = [0, 2, 1, 3]
self.fused_reshape_Y = [0, 0, 12, 64]
self.out = np.matmul(
self.x.reshape([2, 128, 12, 64]).transpose([0, 2, 1, 3]),
self.y.reshape([2, 128, 12, 64]).transpose([0, 2, 3, 1]),
)
class TestReshapeTransposeMatMulOp4DXYInt8(
TestReshapeTransposeMatMulOp4DXYFloat
):
def init_data_type(self):
self.data_type_ = 'int8'
class TestReshapeTransposeMatMulOp2DXFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = np.random.random([2, 5, 10]).astype("float32")
self.y = (
np.random.random([2, 5, 10])
.astype("float32")
.reshape([10, 10])
.transpose([1, 0])
)
self.fused_transpose_X = [1, 0]
self.fused_reshape_X = [10, 10]
self.fused_transpose_Y = []
self.fused_reshape_Y = []
self.out = np.matmul(
self.x.reshape([10, 10]).transpose([1, 0]), self.y.transpose([1, 0])
)
class TestReshapeTransposeMatMulOp2DXInt8(TestReshapeTransposeMatMulOp2DXFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestReshapeTransposeMatMulOp2DYFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = (
np.random.random([2, 5, 10])
.astype("float32")
.reshape([10, 10])
.transpose([1, 0])
)
self.y = np.random.random([2, 5, 10]).astype("float32")
self.fused_transpose_X = []
self.fused_reshape_X = []
self.fused_transpose_Y = [1, 0]
self.fused_reshape_Y = [10, 10]
self.out = np.matmul(self.x, self.y.reshape([10, 10]))
class TestReshapeTransposeMatMulOp2DYInt8(TestReshapeTransposeMatMulOp2DYFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestReshapeTransposeMatMulOp3DXFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = np.random.random([2, 2, 5, 5]).astype("float32")
self.y = (
np.random.random([2, 2, 5, 5])
.astype("float32")
.reshape([2, 10, 5])
.transpose([0, 2, 1])
)
self.fused_transpose_X = [0, 2, 1]
self.fused_reshape_X = [2, 10, 5]
self.fused_transpose_Y = []
self.fused_reshape_Y = []
self.out = np.matmul(
self.x.reshape([2, 10, 5]).transpose(0, 2, 1),
self.y.transpose(0, 2, 1),
)
class TestReshapeTransposeMatMulOp3DXInt8(TestReshapeTransposeMatMulOp3DXFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestReshapeTransposeMatMulOp3DYFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = (
np.random.random([2, 2, 5, 5])
.astype(self.data_type_)
.reshape([2, 10, 5])
.transpose([0, 2, 1])
)
self.y = np.random.random([2, 2, 5, 5]).astype(self.data_type_)
self.fused_transpose_X = []
self.fused_reshape_X = []
self.fused_transpose_Y = [0, 2, 1]
self.fused_reshape_Y = [2, 10, 5]
self.out = np.matmul(self.x, self.y.reshape([2, 10, 5]))
class TestReshapeTransposeMatMulOp3DYInt8(TestReshapeTransposeMatMulOp3DYFloat):
def init_data_type(self):
self.data_type_ = 'int8'
@skip_check_grad_ci(reason="Tests inference only optimization.")
class TestMatMulOpTransposeReshapeEmptyFloat(OpTest):
def init_data_type(self):
self.data_type_ = np.float32
def generate_data(self):
self.bs = 1
self.x = np.random.random([self.bs, 128, 128]).astype(self.data_type_)
self.y = np.random.random([self.bs, 128, 64]).astype(self.data_type_)
def init_params_and_out(self):
self.transpose_out = []
self.reshape_out = []
self.out = np.matmul(self.x, self.y)
def set_op_type(self):
self.op_type = "matmul"
def setUp(self):
self.set_op_type()
self._cpu_only = True
self.use_mkldnn = True
self.init_data_type()
self.generate_data()
self.init_params_and_out()
self.inputs = {'X': self.x, 'Y': self.y}
self.attrs = {'use_mkldnn': self.use_mkldnn}
if len(self.reshape_out) > 0:
self.attrs['fused_reshape_Out'] = self.reshape_out
if len(self.transpose_out) > 0:
self.attrs['fused_transpose_Out'] = self.transpose_out
self.inputs = {'X': self.x, 'Y': self.y}
self.outputs = {'Out': self.out}
def test_check_output(self):
self.check_output()
def check_raise_error(self, msg):
try:
self.check_output()
except Exception as e:
if msg in str(e):
raise AttributeError
else:
print(e)
class TestMatMulOpTransposeReshapeIntEmptyInt(
TestMatMulOpTransposeReshapeEmptyFloat
):
def init_data_type(self):
self.data_type_ = np.int8
class TestMatMulOpTransposeReshapeBasicFloat(
TestMatMulOpTransposeReshapeEmptyFloat
):
def generate_data(self):
self.bs = 8
self.x = np.random.random([self.bs, 12, 128, 128]).astype(
self.data_type_
)
self.y = np.random.random([self.bs, 12, 128, 64]).astype(
self.data_type_
)
def init_params_and_out(self):
self.transpose_out = [0, 2, 1, 3]
self.reshape_out = [0, 0, self.x.shape[1] * self.y.shape[-1]]
self.out = (
np.matmul(self.x, self.y)
.transpose([0, 2, 1, 3])
.reshape([self.bs, -1, self.x.shape[1] * self.y.shape[-1]])
)
class TestMatMulOpTransposeReshapeBasicInt(
TestMatMulOpTransposeReshapeBasicFloat
):
def init_data_type(self):
self.data_type_ = np.int8
class TestMatMulOpTransposeReshapeOtherDimFloat(
TestMatMulOpTransposeReshapeBasicFloat
):
def generate_data(self):
self.bs = 11
self.x = np.random.random([self.bs, 12, 14, 18]).astype(self.data_type_)
self.y = np.random.random([self.bs, 12, 18, 13]).astype(self.data_type_)
class TestMatMulOpTransposeReshapeOtherDimInt(
TestMatMulOpTransposeReshapeOtherDimFloat
):
def init_data_type(self):
self.data_type_ = np.int8
if __name__ == "__main__": if __name__ == "__main__":
from paddle import enable_static from paddle import enable_static
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册