Unverified commit bc3afd82, authored by iSerendipity, committed by GitHub

Add output defs for fused_matmul kernel (#51326)

* remove fused_matmul from list

* add infermeta for fused matmul
Parent 059699a2
@@ -63,7 +63,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"eigh",
"ftt_c2r",
"ftt_r2c",
"fused_matmul",
"generate_proposals",
"graph_sample_neighbors",
"group_norm",
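Note: with "fused_matmul" removed from OpsNeedSetOutputDtypeWhenRegisterPhiKernel, the framework no longer force-sets this kernel's output dtype at registration time; the phi::FusedMatmulInferMeta added below infers both the output shape and the output dtype instead.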
@@ -12,7 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/matmul_v2_op.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
@@ -37,96 +44,6 @@ static std::vector<int64_t> GetInputShape(phi::DDim dim,
class FusedMatmulOp : public MatMulV2Op {
public:
using MatMulV2Op::MatMulV2Op;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fused_matmul");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "fused_matmul");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fused_matmul");
bool trans_x = ctx->Attrs().Get<bool>("trans_x");
bool trans_y = ctx->Attrs().Get<bool>("trans_y");
std::vector<int64_t> dims_x =
GetInputShape(ctx->GetInputDim("X"),
ctx->Attrs().Get<std::vector<int>>("fused_reshape_X"),
ctx->Attrs().Get<std::vector<int>>("fused_transpose_X"));
std::vector<int64_t> dims_y =
GetInputShape(ctx->GetInputDim("Y"),
ctx->Attrs().Get<std::vector<int>>("fused_reshape_Y"),
ctx->Attrs().Get<std::vector<int>>("fused_transpose_Y"));
auto ndims_x = dims_x.size();
auto ndims_y = dims_y.size();
PADDLE_ENFORCE_GT(ndims_x,
0,
phi::errors::InvalidArgument(
"The Input(X) dims size must be greater than 0,"
" but received dims size is 0. "));
PADDLE_ENFORCE_GT(ndims_y,
0,
phi::errors::InvalidArgument(
"The Input(Y) dims size must be greater than 0,"
" but received dims size is 0. "));
bool x_broadcasted = false;
bool y_broadcasted = false;
if (ndims_x == 1) {
dims_x.insert(dims_x.begin(), 1);
ndims_x = 2;
x_broadcasted = true;
}
if (ndims_y == 1) {
dims_y.push_back(1);
ndims_y = 2;
y_broadcasted = true;
}
size_t M, N;
if (trans_x) {
M = dims_x[ndims_x - 1];
} else {
M = dims_x[ndims_x - 2];
}
if (trans_y) {
N = dims_y[ndims_y - 2];
} else {
N = dims_y[ndims_y - 1];
}
std::vector<int64_t> new_dims;
if (ndims_x > ndims_y) {
new_dims.assign(dims_x.begin(), dims_x.end() - 2);
} else if (ndims_x < ndims_y) {
new_dims.assign(dims_y.begin(), dims_y.end() - 2);
} else {
new_dims.reserve(ndims_x);
for (size_t i = 0; i < ndims_x - 2; ++i) {
new_dims.push_back(std::max(dims_x[i], dims_y[i]));
}
}
if (!x_broadcasted) {
new_dims.push_back(M);
}
if (!y_broadcasted) {
new_dims.push_back(N);
}
if (x_broadcasted && y_broadcasted) {
new_dims.push_back(1);
}
auto ddim_out = phi::make_ddim(new_dims);
auto shape = ctx->Attrs().Get<std::vector<int>>("fused_reshape_Out");
auto axis = ctx->Attrs().Get<std::vector<int>>("fused_transpose_Out");
auto is_output_fused = (!shape.empty() && !axis.empty());
if (is_output_fused) {
ddim_out = ddim_out.transpose(axis).reshape(shape);
}
ctx->SetOutputDim("Out", ddim_out);
ctx->ShareLoD("X", "Out");
}
};
class FusedMatmulOpMaker : public MatMulV2OpMaker {
@@ -198,9 +115,13 @@ class FusedMatmulOpMaker : public MatMulV2OpMaker {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(fused_matmul,
FusedMatmulInferShapeFunctor,
PD_INFER_META(phi::FusedMatmulInferMeta));
REGISTER_OPERATOR(
fused_matmul,
ops::FusedMatmulOp,
ops::FusedMatmulOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
FusedMatmulInferShapeFunctor);
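Note: DECLARE_INFER_SHAPE_FUNCTOR together with PD_INFER_META wraps the phi-side FusedMatmulInferMeta as a fluid InferShape functor, which is what allows the hand-written FusedMatmulOp::InferShape override above to be deleted wholesale.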
@@ -1297,6 +1297,138 @@ void FillDiagonalTensorInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
// Used in FusedMatmulInferMeta
static std::vector<int64_t> GetInputShape(phi::DDim dim,
std::vector<int> shape,
std::vector<int> axis) {
PADDLE_ENFORCE_GT(dim.size(),
0,
phi::errors::InvalidArgument(
"The input of fused_matmul has not been initialized "
"properly. The shape of the input = [%s].",
dim));
auto is_input_fused = (!shape.empty() && !axis.empty());
if (is_input_fused) {
dim = dim.reshape(shape).transpose(axis);
}
return phi::vectorize(dim);
}
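// Worked example for GetInputShape (illustrative values, not part of the
// patch): a raw dim of [2, 12] with shape = {2, 3, 4} and axis = {1, 0, 2}
// is first reshaped to [2, 3, 4] and then transposed to [3, 2, 4]
// (out[i] = in[axis[i]]), so the matmul below operates on the fused layout
// rather than the tensor's stored shape.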
void FusedMatmulInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& residual_data,
bool transpose_x,
bool transpose_y,
const float matmul_alpha,
const std::string& fuse_activation,
const float fuse_alpha,
const float fuse_beta,
const float fused_output_scale,
const std::vector<int>& fused_reshape_X,
const std::vector<int>& fused_transpose_X,
const std::vector<int>& fused_reshape_Y,
const std::vector<int>& fused_transpose_Y,
const std::vector<int>& fused_reshape_Out,
const std::vector<int>& fused_transpose_Out,
const std::string& mkldnn_data_type,
const float scale_x,
const float scale_y,
const float scale_in_eltwise,
const float scale_out,
const bool force_fp32_output,
MetaTensor* out) {
std::vector<int64_t> dims_x =
GetInputShape(x.dims(), fused_reshape_X, fused_transpose_X);
std::vector<int64_t> dims_y =
GetInputShape(y.dims(), fused_reshape_Y, fused_transpose_Y);
auto ndims_x = dims_x.size();
auto ndims_y = dims_y.size();
PADDLE_ENFORCE_GT(ndims_x,
0,
phi::errors::InvalidArgument(
"The Input(X) dims size must be greater than 0,"
" but received dims size is 0. "));
PADDLE_ENFORCE_GT(ndims_y,
0,
phi::errors::InvalidArgument(
"The Input(Y) dims size must be greater than 0,"
" but received dims size is 0. "));
bool x_broadcasted = false;
bool y_broadcasted = false;
if (ndims_x == 1) {
dims_x.insert(dims_x.begin(), 1);
ndims_x = 2;
x_broadcasted = true;
}
if (ndims_y == 1) {
dims_y.push_back(1);
ndims_y = 2;
y_broadcasted = true;
}
size_t M, N;
if (transpose_x) {
M = dims_x[ndims_x - 1];
} else {
M = dims_x[ndims_x - 2];
}
if (transpose_y) {
N = dims_y[ndims_y - 2];
} else {
N = dims_y[ndims_y - 1];
}
std::vector<int64_t> new_dims;
if (ndims_x > ndims_y) {
new_dims.assign(dims_x.begin(), dims_x.end() - 2);
} else if (ndims_x < ndims_y) {
new_dims.assign(dims_y.begin(), dims_y.end() - 2);
} else {
new_dims.reserve(ndims_x);
for (size_t i = 0; i < ndims_x - 2; ++i) {
new_dims.push_back(std::max(dims_x[i], dims_y[i]));
}
}
if (!x_broadcasted) {
new_dims.push_back(M);
}
if (!y_broadcasted) {
new_dims.push_back(N);
}
if (x_broadcasted && y_broadcasted) {
new_dims.push_back(1);
}
auto ddim_out = phi::make_ddim(new_dims);
std::vector<int> shape = fused_reshape_Out;
const std::vector<int>& axis = fused_transpose_Out;
auto is_output_fused = (!shape.empty() && !axis.empty());
if (is_output_fused) {
ddim_out = ddim_out.transpose(axis).reshape(shape);
}
out->set_dims(ddim_out);
bool is_int8 = (x.dtype() == DataType::UINT8 || x.dtype() == DataType::INT8);
bool is_bfloat16 = (x.dtype() == DataType::BFLOAT16);
bool fuse_relu = false;
if (fuse_activation == "relu" || fuse_activation == "relu6") {
fuse_relu = true;
}
if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
out->set_dtype(DataType::FLOAT32);
} else if (is_bfloat16) {
out->set_dtype(DataType::BFLOAT16);
} else if (fuse_relu) {
out->set_dtype(DataType::UINT8);
} else {
out->set_dtype(DataType::INT8);
}
}
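// Worked examples for the shape logic above (assumed inputs, not part of
// the patch): X = [5, 3, 4] and Y = [4, 2] with both transpose flags false
// give M = 3, N = 2, and batch dims [5], hence Out = [5, 3, 2]. Two 1-D
// inputs X = [4] and Y = [4] are both promoted to 2-D and flagged as
// broadcasted, so Out collapses to [1].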
void GatherInferMeta(const MetaTensor& x,
const MetaTensor& index,
const Scalar& axis,
@@ -222,6 +222,30 @@ void FillDiagonalTensorInferMeta(const MetaTensor& x,
int dim2,
MetaTensor* out);
void FusedMatmulInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& residual_data,
bool transpose_x,
bool transpose_y,
const float matmul_alpha,
const std::string& fuse_activation,
const float fuse_alpha,
const float fuse_beta,
const float fused_output_scale,
const std::vector<int>& fused_reshape_X,
const std::vector<int>& fused_transpose_X,
const std::vector<int>& fused_reshape_Y,
const std::vector<int>& fused_transpose_Y,
const std::vector<int>& fused_reshape_Out,
const std::vector<int>& fused_transpose_Out,
const std::string& mkldnn_data_type,
const float scale_x,
const float scale_y,
const float scale_in_eltwise,
const float scale_out,
const bool force_fp32_output,
MetaTensor* out);
void GatherInferMeta(const MetaTensor& x,
const MetaTensor& index,
const Scalar& axis,
@@ -523,4 +523,6 @@ PD_REGISTER_KERNEL(fused_matmul,
float,
phi::dtype::bfloat16,
int8_t,
uint8_t) {}
uint8_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
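Setting the output slot to DataType::UNDEFINED tells the phi registrar not to pin the output dtype statically; FusedMatmulInferMeta resolves it at infer time instead. A minimal sketch of that rule pulled out as a free function (hypothetical helper for illustration, not part of the patch):

#include <string>

#include "paddle/phi/common/data_type.h"

// Mirrors the dtype selection performed by FusedMatmulInferMeta above.
phi::DataType FusedMatmulOutDtype(phi::DataType x_dtype,
                                  const std::string& fuse_activation,
                                  bool force_fp32_output) {
  const bool is_int8 =
      x_dtype == phi::DataType::UINT8 || x_dtype == phi::DataType::INT8;
  const bool is_bfloat16 = x_dtype == phi::DataType::BFLOAT16;
  if (force_fp32_output || (!is_int8 && !is_bfloat16)) {
    return phi::DataType::FLOAT32;  // default float path
  }
  if (is_bfloat16) {
    return phi::DataType::BFLOAT16;
  }
  // int8 path: relu/relu6 outputs are non-negative, so unsigned int8 is used.
  return (fuse_activation == "relu" || fuse_activation == "relu6")
             ? phi::DataType::UINT8
             : phi::DataType::INT8;
}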