未验证 提交 ad80fbfe 编写于 作者: W Wang Xin 提交者: GitHub

static graph autogen code support for matmul op (#54338)

* static graph autogen code support for matmul op

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug
上级 87054fe3
......@@ -18,12 +18,22 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/matmul_v2_op.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle {
namespace operators {
/**
 * Promote a vector shape to a row-matrix shape {1, n}.
 *
 * Shapes with more than one axis are handed back unchanged; anything else is
 * rebuilt from its first extent.
 */
static framework::DDim RowMatrixFromVector(const framework::DDim &x_dim) {
  const bool already_matrix = x_dim.size() > 1;
  return already_matrix ? x_dim : phi::make_ddim({1, x_dim[0]});
}
class FusedFeedForwardOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......
......@@ -17,7 +17,6 @@
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/matmul_v2_op.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
......@@ -41,14 +40,141 @@ static std::vector<int64_t> GetInputShape(phi::DDim dim,
return phi::vectorize(dim);
}
class FusedMatmulOp : public MatMulV2Op {
class FusedMatmulOp : public framework::OperatorWithKernel {
public:
using MatMulV2Op::MatMulV2Op;
using framework::OperatorWithKernel::OperatorWithKernel;
// Computes the dims of output "Out" from the dims of inputs "X" and "Y" and
// the trans_x / trans_y attributes, then shares X's LoD with Out.
//
// Steps, as implemented below:
//   * a 1-axis X is viewed as {1, K} and a 1-axis Y as {K, 1}; the axis that
//     was added this way is left out of the result;
//   * the leading (batch) axes are copied from whichever input has more
//     axes, or, for equal ranks, are the element-wise maximum of both;
//   * M and N are read from the last two axes of X and Y, honouring the
//     transpose flags.
//
// OP_INOUT_CHECK guards the presence of X, Y and Out; PADDLE_ENFORCE_GT
// reports InvalidArgument when either input has zero axes.
void InferShape(framework::InferShapeContext* ctx) const override {
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "matmul_v2");
  OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "matmul_v2");
  OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "matmul_v2");
  const bool trans_x = ctx->Attrs().Get<bool>("trans_x");
  const bool trans_y = ctx->Attrs().Get<bool>("trans_y");

  std::vector<int64_t> dims_x = phi::vectorize(ctx->GetInputDim("X"));
  std::vector<int64_t> dims_y = phi::vectorize(ctx->GetInputDim("Y"));
  auto ndims_x = dims_x.size();
  auto ndims_y = dims_y.size();
  PADDLE_ENFORCE_GT(
      ndims_x,
      0,
      phi::errors::InvalidArgument(
          "The first input tensor X's dimension size must be greater than 0,"
          " but received the first input tensor X's dimension size is 0. "));
  PADDLE_ENFORCE_GT(
      ndims_y,
      0,
      phi::errors::InvalidArgument(
          "The second input tensor Y's dimension size must be greater than 0,"
          " but received the second input tensor Y's dimension size is 0. "));

  // Vectors are temporarily lifted to matrices so the indexing below can
  // always address the last two axes.
  bool x_broadcasted = false;
  bool y_broadcasted = false;
  if (ndims_x == 1) {
    dims_x.insert(dims_x.begin(), 1);
    ndims_x = 2;
    x_broadcasted = true;
  }
  if (ndims_y == 1) {
    dims_y.push_back(1);
    ndims_y = 2;
    y_broadcasted = true;
  }

  // Extents are int64_t throughout; M and N stay in that type so a value is
  // never squeezed through an unsigned type on its way into new_dims.
  const int64_t M = trans_x ? dims_x[ndims_x - 1] : dims_x[ndims_x - 2];
  const int64_t N = trans_y ? dims_y[ndims_y - 2] : dims_y[ndims_y - 1];

  std::vector<int64_t> new_dims;
  new_dims.reserve(std::max(ndims_x, ndims_y));
  if (ndims_x > ndims_y) {
    new_dims.assign(dims_x.begin(), dims_x.end() - 2);
  } else if (ndims_x < ndims_y) {
    new_dims.assign(dims_y.begin(), dims_y.end() - 2);
  } else {
    for (size_t i = 0; i < ndims_x - 2; ++i) {
      new_dims.push_back(std::max(dims_x[i], dims_y[i]));
    }
  }
  // An axis that only exists because a vector was lifted is not reported.
  if (!x_broadcasted) {
    new_dims.push_back(M);
  }
  if (!y_broadcasted) {
    new_dims.push_back(N);
  }

  ctx->SetOutputDim("Out", phi::make_ddim(new_dims));
  ctx->ShareLoD("X", "Out");
}
protected:
// Builds the kernel key from the data type that
// IndicateOrPromoteVarDataTypes yields for "X" and "Y", paired with the
// place of the execution context.
phi::KernelKey GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  const auto promoted_dtype =
      OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
  const auto& place = ctx.GetPlace();
  return phi::KernelKey(promoted_dtype, place);
}
// Chooses the kernel key used for one input tensor.
//
//   * Expected dtype is complex: the tensor's own place, layout and dtype
//     are returned unchanged.
//   * oneDNN builds only: when the expected layout is ONEDNN, the tensor is
//     not yet ONEDNN and the thread-local Paddle layout is NHWC, the key
//     reports NHWC with the expected dtype (the original author's note: this
//     covers the op being the first oneDNN op after non-oneDNN ops).
//   * Otherwise: the tensor's place and layout with the expected dtype.
//
// var_name is not consulted.
phi::KernelKey GetKernelTypeForVar(
    const std::string& var_name,
    const phi::DenseTensor& tensor,
    const phi::KernelKey& expected_kernel_type) const override {
  if (framework::IsComplexType(expected_kernel_type.dtype())) {
    return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
  }

#ifdef PADDLE_WITH_MKLDNN
  const bool entering_onednn =
      (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
      (tensor.layout() != phi::DataLayout::ONEDNN);
  if (entering_onednn &&
      phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
          phi::DataLayout::kNHWC) {
    return phi::KernelKey(tensor.place(),
                          phi::DataLayout::kNHWC,
                          expected_kernel_type.dtype());
  }
#endif

  return phi::KernelKey(
      tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
};
class FusedMatmulOpMaker : public MatMulV2OpMaker {
class FusedMatmulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
// Declares the op interface: inputs X and Y, output Out, the two boolean
// transpose attributes (both default to false) and the op comment, then
// hands over to Apply() for any further declarations.
void Make() final {
  // Tensor slots.
  AddInput("X", "tensor of shape (d0, d1 ... M, K)");
  AddInput("Y", "tensor of shape (d0, d1 ... K, N)");
  AddOutput("Out", "tensor of shape (d0, d1 ... M, N)");

  // Transpose switches.
  const char* const trans_x_doc =
      "Set true to transpose the last two dimensions of X before "
      "doing multiplication";
  const char* const trans_y_doc =
      "Set true to transpose the last two dimensions of Y before "
      "doing multiplication";
  AddAttr<bool>("trans_x", trans_x_doc).SetDefault(false);
  AddAttr<bool>("trans_y", trans_y_doc).SetDefault(false);

  AddComment(
      R"DOC(Matrix multiplication Out = X * Y. A has shape (d0, d1 ... M, K),
B has shape (d0, d1 ... K, N), Out has shape ((d0, d1 ... M, N)).
In addition, it also follows the broadcast rule which is similar as
numpy.matmul.
)DOC");

  Apply();
}
protected:
void Apply() override {
void Apply() {
AddInput("ResidualData",
"Extra input from matmul_elementwise_add_mkldnn_fuse_pass")
.AsDispensable()
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/matmul_v2_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
namespace paddle {
namespace operators {
// Computes the dims of output "Out" from the dims of inputs "X" and "Y" and
// the trans_x / trans_y attributes, then shares X's LoD with Out.
//
// Steps, as implemented below:
//   * a 1-axis X is viewed as {1, K} and a 1-axis Y as {K, 1}; the axis that
//     was added this way is left out of the result;
//   * the leading (batch) axes are copied from whichever input has more
//     axes, or, for equal ranks, are the element-wise maximum of both;
//   * M and N are read from the last two axes of X and Y, honouring the
//     transpose flags.
//
// OP_INOUT_CHECK guards the presence of X, Y and Out; PADDLE_ENFORCE_GT
// reports InvalidArgument when either input has zero axes.
void MatMulV2Op::InferShape(framework::InferShapeContext* ctx) const {
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "matmul_v2");
  OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "matmul_v2");
  OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "matmul_v2");
  const bool trans_x = ctx->Attrs().Get<bool>("trans_x");
  const bool trans_y = ctx->Attrs().Get<bool>("trans_y");

  std::vector<int64_t> dims_x = phi::vectorize(ctx->GetInputDim("X"));
  std::vector<int64_t> dims_y = phi::vectorize(ctx->GetInputDim("Y"));
  auto ndims_x = dims_x.size();
  auto ndims_y = dims_y.size();
  PADDLE_ENFORCE_GT(ndims_x,
                    0,
                    phi::errors::InvalidArgument(
                        "The Input(X) dims size must be greater than 0,"
                        " but received dims size is 0. "));
  PADDLE_ENFORCE_GT(ndims_y,
                    0,
                    phi::errors::InvalidArgument(
                        "The Input(Y) dims size must be greater than 0,"
                        " but received dims size is 0. "));

  // Vectors are temporarily lifted to matrices so the indexing below can
  // always address the last two axes.
  bool x_broadcasted = false;
  bool y_broadcasted = false;
  if (ndims_x == 1) {
    dims_x.insert(dims_x.begin(), 1);
    ndims_x = 2;
    x_broadcasted = true;
  }
  if (ndims_y == 1) {
    dims_y.push_back(1);
    ndims_y = 2;
    y_broadcasted = true;
  }

  // Extents are int64_t throughout; M and N stay in that type so a value is
  // never squeezed through an unsigned type on its way into new_dims.
  const int64_t M = trans_x ? dims_x[ndims_x - 1] : dims_x[ndims_x - 2];
  const int64_t N = trans_y ? dims_y[ndims_y - 2] : dims_y[ndims_y - 1];

  std::vector<int64_t> new_dims;
  new_dims.reserve(std::max(ndims_x, ndims_y));
  if (ndims_x > ndims_y) {
    new_dims.assign(dims_x.begin(), dims_x.end() - 2);
  } else if (ndims_x < ndims_y) {
    new_dims.assign(dims_y.begin(), dims_y.end() - 2);
  } else {
    for (size_t i = 0; i < ndims_x - 2; ++i) {
      new_dims.push_back(std::max(dims_x[i], dims_y[i]));
    }
  }
  // An axis that only exists because a vector was lifted is not reported.
  if (!x_broadcasted) {
    new_dims.push_back(M);
  }
  if (!y_broadcasted) {
    new_dims.push_back(N);
  }

  ctx->SetOutputDim("Out", phi::make_ddim(new_dims));
  ctx->ShareLoD("X", "Out");
}
// Builds the kernel key from the data type that
// IndicateOrPromoteVarDataTypes yields for "X" and "Y", paired with the
// place of the execution context.
phi::KernelKey MatMulV2Op::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  const auto promoted_dtype =
      OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
  const auto& place = ctx.GetPlace();
  return phi::KernelKey(promoted_dtype, place);
}
// Chooses the kernel key used for one input tensor.
//
//   * Expected dtype is complex: the tensor's own place, layout and dtype
//     are returned unchanged.
//   * oneDNN builds only: when the expected layout is ONEDNN, the tensor is
//     not yet ONEDNN and the thread-local Paddle layout is NHWC, the key
//     reports NHWC with the expected dtype (the original author's note: this
//     covers matmul_v2 being the first oneDNN op after non-oneDNN ops).
//   * Otherwise: the tensor's place and layout with the expected dtype.
//
// var_name is not consulted.
phi::KernelKey MatMulV2Op::GetKernelTypeForVar(
    const std::string& var_name,
    const phi::DenseTensor& tensor,
    const phi::KernelKey& expected_kernel_type) const {
  if (framework::IsComplexType(expected_kernel_type.dtype())) {
    return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
  }

#ifdef PADDLE_WITH_MKLDNN
  const bool entering_onednn =
      (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
      (tensor.layout() != phi::DataLayout::ONEDNN);
  if (entering_onednn &&
      phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
          phi::DataLayout::kNHWC) {
    return phi::KernelKey(
        tensor.place(), phi::DataLayout::kNHWC, expected_kernel_type.dtype());
  }
#endif

  return phi::KernelKey(
      tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
// Declares the matmul_v2 interface: inputs X and Y, output Out, the two
// boolean transpose attributes (both default to false) and the op comment,
// then hands over to Apply() for any further declarations.
void MatMulV2OpMaker::Make() {
  // Tensor slots.
  AddInput("X", "tensor of shape (d0, d1 ... M, K)");
  AddInput("Y", "tensor of shape (d0, d1 ... K, N)");
  AddOutput("Out", "tensor of shape (d0, d1 ... M, N)");

  // Transpose switches.
  const char* const trans_x_doc =
      "Set true to transpose the last two dimensions of X before "
      "doing multiplication";
  const char* const trans_y_doc =
      "Set true to transpose the last two dimensions of Y before "
      "doing multiplication";
  AddAttr<bool>("trans_x", trans_x_doc).SetDefault(false);
  AddAttr<bool>("trans_y", trans_y_doc).SetDefault(false);

  AddComment(
      R"DOC(Matrix multiplication Out = X * Y. A has shape (d0, d1 ... M, K),
B has shape (d0, d1 ... K, N), Out has shape ((d0, d1 ... M, N)).
In addition, it also follows the broadcast rule which is similar as
numpy.matmul.
)DOC");

  Apply();
}
// Operator class registered for matmul_v2_grad. It only customises kernel
// key selection; no InferShape override is defined here.
class MatMulV2OpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // The kernel dtype follows the gradient of "Out"; the place comes from the
  // execution context.
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto out_grad_dtype = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));
    return phi::KernelKey(out_grad_dtype, ctx.GetPlace());
  }

  // Per-tensor key: place and layout always come from the tensor. The dtype
  // is the tensor's own when the expected dtype is complex, and the expected
  // dtype otherwise. var_name is not consulted.
  phi::KernelKey GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const phi::KernelKey& expected_kernel_type) const override {
    const bool keep_tensor_dtype =
        framework::IsComplexType(expected_kernel_type.dtype());
    const auto chosen_dtype =
        keep_tensor_dtype ? tensor.dtype() : expected_kernel_type.dtype();
    return phi::KernelKey(tensor.place(), tensor.layout(), chosen_dtype);
  }
};
// Describes the matmul_v2_grad op built from a forward matmul_v2 op:
// forward X and Y plus the gradient of Out go in, gradients of X and Y come
// out, and the forward attributes are copied over.
template <typename T>
class MatMulV2GradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("matmul_v2_grad");

    // Inputs: both forward operands, then the incoming gradient.
    const char* const forward_inputs[] = {"X", "Y"};
    for (const char* slot : forward_inputs) {
      op->SetInput(slot, this->Input(slot));
    }
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));

    // Outputs: one gradient per forward operand, in X-then-Y order.
    for (const char* slot : forward_inputs) {
      op->SetOutput(framework::GradVarName(slot), this->InputGrad(slot));
    }

    op->SetAttrMap(this->Attrs());
  }
};
// Operator class registered for matmul_v2_grad_grad (second-order gradient).
class MatMulV2OpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Output dims are copied from the matching inputs:
  //   DX    <- X     when DX is requested and DDY is supplied
  //   DY    <- Y     when DY is requested and DDX is supplied
  //   DDOut <- DOut  when DDOut is requested and DDX or DDY is supplied
  // X, Y and DOut are mandatory.
  void InferShape(framework::InferShapeContext* context) const override {
    // Report the operator under its registered name so a failed check points
    // at this op, in line with the checks in the triple-grad operator.
    OP_INOUT_CHECK(
        context->HasInput("X"), "Input", "X", "matmul_v2_grad_grad");
    OP_INOUT_CHECK(
        context->HasInput("Y"), "Input", "Y", "matmul_v2_grad_grad");
    OP_INOUT_CHECK(
        context->HasInput("DOut"), "Input", "DOut", "matmul_v2_grad_grad");

    if (context->HasOutput("DX") && context->HasInput("DDY")) {
      context->ShareDim("X", "DX");
    }
    if (context->HasOutput("DY") && context->HasInput("DDX")) {
      context->ShareDim("Y", "DY");
    }
    if (context->HasOutput("DDOut") &&
        (context->HasInput("DDY") || context->HasInput("DDX"))) {
      context->ShareDim("DOut", "DDOut");
    }
  }
};
// Describes the matmul_v2_grad_grad op built from a matmul_v2_grad op.
//
// Inputs : X, Y, DOut (the grad op's Out-gradient input), DDX and DDY (the
//          gradients flowing into the grad op's X/Y gradient outputs).
// Outputs: DDOut only when DDX or DDY is non-empty; DX is the empty input
//          grad when DDY is empty, DY likewise when DDX is empty.
template <typename T>
class MatMulV2OpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("matmul_v2_grad_grad");

    op->SetInput("X", this->Input("X"));
    op->SetInput("Y", this->Input("Y"));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));

    auto ddx = this->OutputGrad(framework::GradVarName("X"));
    auto ddy = this->OutputGrad(framework::GradVarName("Y"));
    op->SetInput("DDX", ddx);
    op->SetInput("DDY", ddy);

    const bool has_ddx = !ddx.empty();
    const bool has_ddy = !ddy.empty();

    if (has_ddx || has_ddy) {
      op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
    }

    if (has_ddy) {
      op->SetOutput("DX", this->InputGrad("X"));
    } else {
      op->SetOutput("DX", this->EmptyInputGrad());
    }

    if (has_ddx) {
      op->SetOutput("DY", this->InputGrad("Y"));
    } else {
      op->SetOutput("DY", this->EmptyInputGrad());
    }

    op->SetAttrMap(this->Attrs());
  }
};
class MatMulCompositeDoubleGradOpMaker : public prim::CompositeGradOpMakerBase {
public:
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
void Apply() override {
// get inputs
paddle::Tensor x = this->GetSingleForwardInput("X");
paddle::Tensor y = this->GetSingleForwardInput("Y");
paddle::Tensor dout =
this->GetSingleForwardInput(framework::GradVarName("Out"));
paddle::optional<paddle::Tensor> ddx =
this->GetOptionalSingleOutputGrad(framework::GradVarName("X"));
paddle::optional<paddle::Tensor> ddy =
this->GetOptionalSingleOutputGrad(framework::GradVarName("Y"));
// get attr
bool trans_x = this->Attr<bool>("trans_x");
bool trans_y = this->Attr<bool>("trans_y");
// get output
paddle::Tensor x_grad_t = this->GetSingleInputGrad("X");
paddle::Tensor y_grad_t = this->GetSingleInputGrad("Y");
paddle::Tensor grad_out_grad_t =
this->GetSingleInputGrad(framework::GradVarName("Out"));
// get output ptr
paddle::Tensor* x_grad = this->GetOutputPtr(&x_grad_t);
paddle::Tensor* y_grad = this->GetOutputPtr(&y_grad_t);
paddle::Tensor* grad_out_grad = this->GetOutputPtr(&grad_out_grad_t);
// get output orginal name
std::string x_grad_name = this->GetOutputName(x_grad_t);
std::string y_grad_name = this->GetOutputName(y_grad_t);
std::string grad_out_grad_name = this->GetOutputName(grad_out_grad_t);
VLOG(3) << "Runing matmul_double_grad composite func";
// call composite backward func
prim::matmul_double_grad<prim::DescTensor>(
x, y, dout, ddx, ddy, trans_x, trans_y, x_grad, y_grad, grad_out_grad);
// recover output name
this->RecoverOutputName(x_grad_t, x_grad_name);
this->RecoverOutputName(y_grad_t, y_grad_name);
this->RecoverOutputName(grad_out_grad_t, grad_out_grad_name);
}
};
// Operator class registered for matmul_v2_triple_grad (third-order gradient).
class MatMulV2OpTripleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // All eight inputs are mandatory. Every output that is present takes its
  // dims from the input listed next to it in the table below.
  void InferShape(framework::InferShapeContext* context) const override {
    const char* const required_inputs[] = {
        "X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"};
    for (const char* input_name : required_inputs) {
      OP_INOUT_CHECK(context->HasInput(input_name),
                     "Input",
                     input_name,
                     "matmul_v2_triple_grad");
    }

    struct DimLink {
      const char* source;  // input whose dims are shared
      const char* target;  // optional output receiving them
    };
    const DimLink dim_links[] = {{"X", "D_X_out"},
                                 {"Y", "D_Y_out"},
                                 {"DOut", "D_DOut_out"},
                                 {"X", "D_DDX_out"},
                                 {"Y", "D_DDY_out"}};
    for (const DimLink& link : dim_links) {
      if (context->HasOutput(link.target)) {
        context->ShareDim(link.source, link.target);
      }
    }
  }
};
// Describes the matmul_v2_triple_grad op built from a matmul_v2_grad_grad op.
//
// Inputs : the double-grad op's own inputs (X, Y, DOut, DDX, DDY) plus the
//          gradients of its outputs (D_DX, D_DY, D_DDOut).
// Outputs: one gradient per double-grad input.
template <typename T>
class MatMulV2OpTripleGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("matmul_v2_triple_grad");

    // Slots forwarded unchanged from the double-grad op.
    const char* const forwarded[] = {"X", "Y", "DOut", "DDX", "DDY"};
    for (const char* slot : forwarded) {
      op->SetInput(slot, this->Input(slot));
    }

    // Gradients arriving at the double-grad op's outputs.
    op->SetInput("D_DX", this->OutputGrad("DX"));
    op->SetInput("D_DY", this->OutputGrad("DY"));
    op->SetInput("D_DDOut", this->OutputGrad("DDOut"));

    // Gradients produced for each double-grad input, in declaration order.
    struct GradSlot {
      const char* output;  // slot on the triple-grad op
      const char* source;  // double-grad input it differentiates
    };
    const GradSlot grad_slots[] = {{"D_X_out", "X"},
                                   {"D_Y_out", "Y"},
                                   {"D_DOut_out", "DOut"},
                                   {"D_DDX_out", "DDX"},
                                   {"D_DDY_out", "DDY"}};
    for (const GradSlot& entry : grad_slots) {
      op->SetOutput(entry.output, this->InputGrad(entry.source));
    }

    op->SetAttrMap(this->Attrs());
  }
};
} // namespace operators
} // namespace paddle
// Operator registrations for the matmul_v2 family. Each gradient operator is
// registered together with the maker(s) that describe the next-order
// gradient operator.
namespace ops = paddle::operators;
// Forward op: operator class, proto maker, and the first-order grad maker
// instantiated for both OpDesc and imperative OpBase.
REGISTER_OPERATOR(matmul_v2,
ops::MatMulV2Op,
ops::MatMulV2OpMaker,
ops::MatMulV2GradOpMaker<paddle::framework::OpDesc>,
ops::MatMulV2GradOpMaker<paddle::imperative::OpBase>);
// Shape inference for matmul_v2_grad is delegated to
// phi::GeneralBinaryGradInferMeta through this functor.
DECLARE_INFER_SHAPE_FUNCTOR(matmul_v2_grad,
MatMulV2GradInferShapeFunctor,
PD_INFER_META(phi::GeneralBinaryGradInferMeta));
// First-order grad op: double-grad makers, the composite double-grad maker
// and the infer-shape functor declared above.
REGISTER_OPERATOR(matmul_v2_grad,
ops::MatMulV2OpGrad,
ops::MatMulV2OpDoubleGradMaker<paddle::framework::OpDesc>,
ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>,
ops::MatMulCompositeDoubleGradOpMaker,
MatMulV2GradInferShapeFunctor);
// Second-order grad op, with the triple-grad makers.
REGISTER_OPERATOR(matmul_v2_grad_grad,
ops::MatMulV2OpDoubleGrad,
ops::MatMulV2OpTripleGradMaker<paddle::framework::OpDesc>,
ops::MatMulV2OpTripleGradMaker<paddle::imperative::OpBase>);
// Third-order grad op: registered without any further grad maker.
REGISTER_OPERATOR(matmul_v2_triple_grad, ops::MatMulV2OpTripleGrad);
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <functional>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
// only can include the headers in paddle/phi/api dirs
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/phi/kernels/matmul_grad_kernel.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#endif
namespace paddle {
namespace operators {
// Operator class for matmul_v2. Only declarations appear here; the member
// functions are defined out of line.
class MatMulV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
// Shape inference hook for the operator.
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
// Returns the kernel key expected for the given execution context.
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
// Returns the kernel key to use for a single input tensor, given the key
// the kernel expects.
phi::KernelKey GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const phi::KernelKey& expected_kernel_type) const override;
};
// Proto/checker maker for matmul_v2. Make() is final, so a derived maker
// cannot replace it; Apply() is the virtual hook left open to subclasses.
class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
// Defined out of line.
void Make() final;
protected:
// Extension hook with an empty default body.
// NOTE(review): presumably invoked from Make() so derived makers can add
// extra inputs/attributes — confirm against the Make() definition.
virtual void Apply() {}
};
// Returns a copy of `input` whose first two axes are merged: a rank-3
// P x M x N tensor comes back as (P * M) x N. For any other rank the copy
// is returned with its dims untouched.
static phi::DenseTensor FoldInitDims(const phi::DenseTensor& input) {
  phi::DenseTensor folded = input;
  const auto dims = input.dims();
  if (dims.size() != 3) {
    return folded;
  }
  folded.Resize({dims[0] * dims[1], dims[2]});
  return folded;
}
/**
 * Promote a vector shape to a row-matrix shape {1, n}.
 *
 * Shapes with more than one axis are handed back unchanged; anything else is
 * rebuilt from its first extent.
 */
static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
  const bool already_matrix = x_dim.size() > 1;
  return already_matrix ? x_dim : phi::make_ddim({1, x_dim[0]});
}
/**
 * Promote a vector shape to a column-matrix shape {n, 1}.
 *
 * Shapes with more than one axis are handed back unchanged; anything else is
 * rebuilt from its first extent.
 */
static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
  const bool already_matrix = y_dim.size() > 1;
  return already_matrix ? y_dim : phi::make_ddim({y_dim[0], 1});
}
/**
 * Resize `x` in place according to a matrix descriptor.
 *
 * The resulting shape is [batch_size, H, W] when the descriptor's batch size
 * is non-zero and [H, W] otherwise, where H/W are the descriptor's
 * height/width, exchanged when the descriptor is marked as transposed.
 */
static void ReshapeTensorIntoMatrixSequence(
    phi::DenseTensor* x, const phi::funcs::MatDescriptor& descriptor) {
  const bool transposed = descriptor.trans_;
  const int64_t rows = transposed ? descriptor.width_ : descriptor.height_;
  const int64_t cols = transposed ? descriptor.height_ : descriptor.width_;

  if (!descriptor.batch_size_) {
    x->Resize({rows, cols});
    return;
  }
  x->Resize({descriptor.batch_size_, rows, cols});
}
// Resizes x, y and out in place so that they form a (possibly batched)
// matrix product.
//
// x is treated as a row matrix and y as a column matrix when they are
// vectors. `out` becomes [x.height, y.width], with a leading batch axis of
// max(x.batch, y.batch) unless both descriptors report a batch size of 0.
static void ReshapeXYOutIntoMatrixSequence(phi::DenseTensor* x,
                                           phi::DenseTensor* y,
                                           phi::DenseTensor* out,
                                           bool trans_x,
                                           bool trans_y) {
  const auto x_shape = RowMatrixFromVector(x->dims());
  const auto y_shape = ColumnMatrixFromVector(y->dims());
  const auto x_desc = phi::funcs::CreateMatrixDescriptor(x_shape, 0, trans_x);
  const auto y_desc = phi::funcs::CreateMatrixDescriptor(y_shape, 0, trans_y);

  const bool batched = x_desc.batch_size_ != 0 || y_desc.batch_size_ != 0;
  if (batched) {
    out->Resize({(std::max)(x_desc.batch_size_, y_desc.batch_size_),
                 x_desc.height_,
                 y_desc.width_});
  } else {
    out->Resize({x_desc.height_, y_desc.width_});
  }

  ReshapeTensorIntoMatrixSequence(x, x_desc);
  ReshapeTensorIntoMatrixSequence(y, y_desc);
}
} // namespace operators
} // namespace paddle
......@@ -1582,7 +1582,7 @@
out : Y
- op : matmul (matmul_v2)
backward : matmul_grad (matmul_v2_grad)
backward : matmul_grad (matmul_v2_grad), matmul_double_grad (matmul_v2_grad_grad), matmul_triple_grad (matmul_v2_triple_grad)
inputs :
{x : X, y : Y}
attrs :
......@@ -1591,6 +1591,7 @@
out : Out
extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]
complex_promote : [X, Y]
- op : matmul_with_flatten (mul)
backward : matmul_with_flatten_grad (mul_grad)
......
......@@ -120,6 +120,41 @@
param : [x, out_grad]
inplace : (out_grad -> x_grad)
- backward_op : matmul_double_grad
forward : matmul_grad (Tensor x, Tensor y, Tensor grad_out, bool transpose_x=false, bool transpose_y=false) -> Tensor(grad_x), Tensor(grad_y)
args : (Tensor x, Tensor y, Tensor grad_out, Tensor grad_x_grad, Tensor grad_y_grad, bool transpose_x=false, bool transpose_y=false)
output : Tensor(x_grad), Tensor(y_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param : [x, y, grad_out]
kernel :
func : matmul_double_grad
composite : matmul_double_grad(x, y, grad_out, grad_x_grad, grad_y_grad, transpose_x, transpose_y, x_grad, y_grad, grad_out_grad)
optional : grad_x_grad, grad_y_grad
backward : matmul_triple_grad
- backward_op : matmul_grad
forward : matmul (Tensor x, Tensor y, bool transpose_x=false, bool transpose_y=false) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad, bool transpose_x=false, bool transpose_y=false)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : matmul_grad
data_type: out_grad
backward : matmul_double_grad
- backward_op : matmul_triple_grad
forward : matmul_double_grad (Tensor x, Tensor y, Tensor grad_out, Tensor grad_grad_x, Tensor grad_grad_y, bool transpose_x=false, bool transpose_y=false) -> Tensor(grad_x), Tensor(grad_y), Tensor(grad_grad_out)
args : (Tensor x, Tensor y, Tensor grad_out, Tensor grad_grad_x, Tensor grad_grad_y, Tensor grad_x_grad, Tensor grad_y_grad, Tensor grad_grad_out_grad, bool transpose_x=false, bool transpose_y=false)
output : Tensor(x_grad), Tensor(y_grad), Tensor(grad_out_grad), Tensor(grad_grad_x_grad), Tensor(grad_grad_y_grad)
infer_meta :
func : GeneralQuinaryGradInferMeta
param : [x, y, grad_out, grad_grad_x, grad_grad_y]
kernel :
func : matmul_triple_grad
- backward_op : max_grad
forward: max (Tensor x, IntArray axis={0}, bool keepdim=false, bool reduce_all=false, int in_dtype=-1, int out_dtype=-1) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, IntArray axis={}, bool keepdim=false, bool reduce_all=false)
......
......@@ -307,6 +307,15 @@
data_transform :
skip_transform : start, stop, number
- op : matmul
args : (Tensor x, Tensor y, bool transpose_x = false, bool transpose_y = false)
output : Tensor
infer_meta :
func : MatmulInferMeta
kernel :
func : matmul
backward : matmul_grad
- op : matrix_rank
args : (Tensor x, Tensor tol_tensor, float tol=0.0f, bool hermitian=false, bool use_default_tol=true)
output : Tensor(out)
......
......@@ -28,6 +28,29 @@ using phi::ReshapeToMatrix;
namespace phi {
// Chooses the kernel key used for the tensor carried by `ctx`.
//
//   * Expected dtype is complex: the tensor's own place, layout and dtype
//     are returned unchanged.
//   * oneDNN builds only: when the expected layout is ONEDNN, the tensor is
//     not yet ONEDNN and the thread-local Paddle layout is NHWC, the key
//     reports NHWC with the expected dtype (the original author's note: this
//     covers matmul_v2 being the first oneDNN op after non-oneDNN ops).
//   * Otherwise: the tensor's place and layout with the expected dtype.
KernelKey MatmulGetkernelTypeForVar(const GetKernelTypeForVarContext *ctx) {
  const DenseTensor &tensor = ctx->GetTensor();
  const KernelKey &expected_kernel_type = ctx->GetKernelKey();

  if (phi::IsComplexType(expected_kernel_type.dtype())) {
    return phi::KernelKey(tensor.place(), tensor.layout(), tensor.dtype());
  }

#ifdef PADDLE_WITH_MKLDNN
  const bool entering_onednn =
      (expected_kernel_type.layout() == phi::DataLayout::ONEDNN) &&
      (tensor.layout() != phi::DataLayout::ONEDNN);
  if (entering_onednn &&
      phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
          phi::DataLayout::kNHWC) {
    return phi::KernelKey(
        tensor.place(), phi::DataLayout::kNHWC, expected_kernel_type.dtype());
  }
#endif

  return phi::KernelKey(
      tensor.place(), tensor.layout(), expected_kernel_type.dtype());
}
void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
const std::vector<int64_t> &y_dims,
std::vector<int64_t> *x_bd_dims,
......@@ -534,7 +557,9 @@ PD_REGISTER_KERNEL(matmul,
float,
phi::dtype::bfloat16,
int8_t,
uint8_t) {}
uint8_t) {
kernel->get_kerneltype_forvar_fn_ = phi::MatmulGetkernelTypeForVar;
}
PD_REGISTER_KERNEL(matmul_with_flatten,
OneDNN,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
// Maps the matmul_v2_grad op onto a kernel signature. The presence of a
// "use_addto" attribute selects the addto_matmul_grad kernel, which takes
// that attribute as an extra argument; otherwise matmul_grad is used.
KernelSignature MatmulGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
  if (!ctx.HasAttr("use_addto")) {
    return KernelSignature("matmul_grad",
                           {"X", "Y", "Out@GRAD"},
                           {"trans_x", "trans_y"},
                           {"X@GRAD", "Y@GRAD"});
  }
  return KernelSignature("addto_matmul_grad",
                         {"X", "Y", "Out@GRAD"},
                         {"trans_x", "trans_y", "use_addto"},
                         {"X@GRAD", "Y@GRAD"});
}
// Maps the matmul_v2_grad_grad op onto the matmul_double_grad kernel.
// The mapping is fixed, so the context is not inspected.
KernelSignature MatmulDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  KernelSignature signature("matmul_double_grad",
                            {"X", "Y", "DOut", "DDX", "DDY"},  // inputs
                            {"trans_x", "trans_y"},            // attributes
                            {"DX", "DY", "DDOut"});            // outputs
  return signature;
}
// Maps the matmul_v2_triple_grad op onto the matmul_triple_grad kernel.
// The mapping is fixed, so the context is not inspected.
KernelSignature MatmulTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  KernelSignature signature(
      "matmul_triple_grad",
      // inputs
      {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
      // attributes
      {"trans_x", "trans_y"},
      // outputs
      {"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", "D_DDY_out"});
  return signature;
}
} // namespace phi
// Associate each matmul_v2* operator name with its base kernel name.
PD_REGISTER_BASE_KERNEL_NAME(matmul_v2, matmul);
PD_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad, matmul_grad);
PD_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad_grad, matmul_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(matmul_v2_triple_grad, matmul_triple_grad);
// Register the argument-mapping functions for the three gradient operators.
// No mapping function is registered for the forward matmul_v2 op here.
PD_REGISTER_ARG_MAPPING_FN(matmul_v2_grad, phi::MatmulGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(matmul_v2_grad_grad,
phi::MatmulDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(matmul_v2_triple_grad,
phi::MatmulTripleGradOpArgumentMapping);
......@@ -11,7 +11,7 @@ if(NOT (NOT WITH_PYTHON AND ON_INFER))
eager_scale
scale_node
generated_op
matmul_v2_op
generated_static_op
dygraph_function
eager_prim_api)
......
......@@ -9,15 +9,14 @@ if(WITH_TESTING AND NOT WIN32)
set(JIT_DEPS
phi
elementwise_add_op
matmul_v2_op
activation_op
reduce_mean_op
feed_op
fetch_op
generated_op
generated_static_op
transfer_layout_op
jit_layer)
jit_layer
generated_static_op)
cc_test(
layer_test
SRCS layer_test.cc
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册