Unverified commit acb78ea2 authored by Sławomir Siwek, committed by GitHub

Offload calculations from matmul op to fuse pass (#44941)

* remove v2_transpose_reshape

* matmul_transpose_reshape

* reshape_transpose_matmul

* Add int8 support for matmulV2

* restore ut

* adjust old ut

* restore parallel UT rules

* remove mkldnn code from base ops

* move enforces to pass

* remove duplicated functions

* delete duplicated enforces

* feedback from review

* add comments to variables

* enable eltwise support

* dynamic attribute

* remove fusepass tests from op test

* remove fuse pass cases from op test

* revert introduction of dynamic attributes

* style
Co-authored-by: wozna <joanna.wozna@intel.com>
Parent c737232f
......@@ -66,27 +66,23 @@ void MatmulTransposeReshapeMKLDNNPass::Fuse(
auto transpose_axis =
PADDLE_GET_CONST(std::vector<int>, transpose_op->Op()->GetAttr("axis"));
auto reshape_out_size = reshape_shape.size();
auto transpose_out_size = transpose_axis.size();
const std::vector<int> supported_axis{0, 2, 1, 3};
const bool supported_transpose_axis = std::equal(
transpose_axis.begin(), transpose_axis.end(), supported_axis.begin());
if (transpose_out_size != 4) {
VLOG(3) << "do not perform " + matmul_type + "_transpose_reshape fuse: "
<< "supported rank is 4, received " << transpose_out_size;
return;
}
if (!supported_transpose_axis) {
if (transpose_axis != supported_axis) {
VLOG(3) << "do not perform " + matmul_type + "_transpose_reshape fuse: "
<< "supported transpose axis for the fuse are {0, 2, 1, 3}";
return;
}
if (reshape_out_size != 3) {
if (reshape_shape.size() != 3) {
VLOG(3) << "do not perform " + matmul_type + "_transpose_reshape fuse: "
<< "reshape_out supported rank is 3, received "
<< reshape_out_size;
<< reshape_shape.size();
return;
}
if (std::count(reshape_shape.begin(), reshape_shape.end(), -1) > 1) {
VLOG(3) << "Only one dim can be undefined / marked as '-1'";
return;
}
OpDesc *matmul_desc = matmul_op->Op();
matmul_desc->SetOutput("Out", {reshape_out->Name()});
matmul_desc->SetAttr("fused_reshape_Out", reshape_shape);
......
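For intuition, here is a standalone sketch of the only pattern this fuse accepts (the attention-style dims are made up; this is not Paddle code): a rank-4 matmul output permuted with {0, 2, 1, 3} and then collapsed to rank 3, which the pass records on the matmul op as fused_transpose_Out / fused_reshape_Out instead of keeping separate transpose2 and reshape2 ops.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical matmul output: [batch, heads, seq_len, head_dim].
  std::vector<int64_t> out = {2, 12, 128, 64};
  const std::vector<int> axis = {0, 2, 1, 3};  // the only permutation the pass accepts
  std::vector<int64_t> transposed(out.size());
  for (size_t i = 0; i < axis.size(); ++i) transposed[i] = out[axis[i]];  // {2, 128, 12, 64}
  // Rank-3 reshape target merging heads * head_dim; at most one entry may be -1.
  std::vector<int64_t> reshaped = {transposed[0], transposed[1],
                                   transposed[2] * transposed[3]};        // {2, 128, 768}
  assert(reshaped.size() == 3);
  return 0;
}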
......@@ -23,14 +23,24 @@ namespace ir {
void ReshapeTransposeMatmulMkldnnFusePass::ApplyImpl(Graph *graph) const {
auto matmul_types = {"matmul", "matmul_v2"};
bool with_reshape_xshape = true;
bool with_transpose_xshape = true;
for (const auto &matmul_type : matmul_types) {
Fuse(graph, matmul_type, with_reshape_xshape, with_transpose_xshape);
Fuse(graph, matmul_type, with_reshape_xshape, !with_transpose_xshape);
Fuse(graph, matmul_type, !with_reshape_xshape, with_transpose_xshape);
Fuse(graph, matmul_type, !with_reshape_xshape, !with_transpose_xshape);
Fuse(graph,
matmul_type,
false /*with_reshape_xshape*/,
false /*with_transpose_xshape*/);
Fuse(graph,
matmul_type,
false /*with_reshape_xshape*/,
true /*with_transpose_xshape*/);
Fuse(graph,
matmul_type,
true /*with_reshape_xshape*/,
false /*with_transpose_xshape*/);
Fuse(graph,
matmul_type,
true /*with_reshape_xshape*/,
true /*with_transpose_xshape*/);
}
}
......@@ -80,29 +90,44 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, rtm_pattern);
auto reshape_shape =
paddle::get<std::vector<int>>(reshape_op->Op()->GetAttr("shape"));
auto transpose_axis =
paddle::get<std::vector<int>>(transpose_op->Op()->GetAttr("axis"));
OpDesc *matmul_desc = matmul_op->Op();
std::string input_var_name = transpose_out->Name();
auto UpdateMatmul = [&](std::string matmul_input_name) {
matmul_desc->SetInput(matmul_input_name, {(reshape_in)->Name()});
matmul_desc->SetAttr("fused_reshape_" + matmul_input_name, reshape_shape);
matmul_desc->SetAttr("fused_transpose_" + matmul_input_name,
transpose_axis);
};
std::string matmul_input_name;
if (matmul_desc->Inputs().at("X").at(0) == input_var_name) {
UpdateMatmul("X");
matmul_input_name = "X";
} else if (matmul_desc->Inputs().at("Y").at(0) == input_var_name) {
UpdateMatmul("Y");
matmul_input_name = "Y";
} else {
throw platform::errors::InvalidArgument("Unexpected input to " +
matmul_type + " encountered.");
}
auto reshape_shape =
paddle::get<std::vector<int>>(reshape_op->Op()->GetAttr("shape"));
auto transpose_axis =
paddle::get<std::vector<int>>(transpose_op->Op()->GetAttr("axis"));
if (reshape_shape.size() < 2 || reshape_shape.size() > 4) {
VLOG(3) << "shape_" + matmul_input_name + " attribute of " + matmul_type +
" was implemented for 2, 3 or 4 dimensions.";
return;
}
if (reshape_shape.size() != transpose_axis.size()) {
VLOG(3) << "Ranks of shape_" + matmul_input_name + " and axis_" +
matmul_input_name + " attributes of " + matmul_type +
" must be equal.";
return;
}
if (std::count(reshape_shape.begin(), reshape_shape.end(), -1) > 1) {
VLOG(3) << "Only one dim can be undefined / marked as '-1'";
return;
}
matmul_desc->SetInput(matmul_input_name, {(reshape_in)->Name()});
matmul_desc->SetAttr("fused_reshape_" + matmul_input_name, reshape_shape);
matmul_desc->SetAttr("fused_transpose_" + matmul_input_name,
transpose_axis);
std::unordered_set<const ir::Node *> nodes_to_remove{
reshape_op, reshape_out, transpose_op, transpose_out};
if (with_reshape_xshape) nodes_to_remove.insert(reshape_xshape);
......
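The guards above are the ones this commit removes from the matmul operators' GetDimForInput (see the hunks below); running them once in the pass means an ineligible reshape/transpose pair is simply left unfused rather than failing inside the op. A standalone restatement of the conditions, with illustrative names (not Paddle code):

#include <algorithm>
#include <cassert>
#include <vector>

// True when a reshape+transpose pair may be folded into
// fused_reshape_X / fused_transpose_X attributes (illustrative only).
bool CanFoldIntoMatmul(const std::vector<int>& shape,
                       const std::vector<int>& axis) {
  if (shape.size() < 2 || shape.size() > 4) return false;  // rank 2, 3 or 4
  if (shape.size() != axis.size()) return false;           // ranks must match
  return std::count(shape.begin(), shape.end(), -1) <= 1;  // at most one -1
}

int main() {
  assert(CanFoldIntoMatmul({0, 0, 12, 64}, {0, 2, 1, 3}));   // typical head-split pattern
  assert(!CanFoldIntoMatmul({-1, -1, 768}, {0, 2, 1}));      // two inferred dims: skip fuse
  assert(!CanFoldIntoMatmul({16, 128, 12, 64}, {0, 2, 1}));  // rank mismatch: skip fuse
  return 0;
}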
......@@ -358,58 +358,7 @@ framework::DDim GetDimForInput(const framework::InferShapeContext &ctx,
"shape of Input(%s) = [%s].",
dim));
// if mkldnn reshape+transpose+matmul fuse activated
if (!shape.empty() && !axis.empty()) {
PADDLE_ENFORCE_GE(
shape.size(),
2,
platform::errors::InvalidArgument(
"shape_%s attribute of MatMulOp was implemented for 2, 3 "
"or 4 dimensions.",
input_name));
PADDLE_ENFORCE_LE(
shape.size(),
4,
platform::errors::InvalidArgument(
"shape_%s attribute of MatMulOp was implemented for 2, 3 "
"or 4 dimensions.",
input_name));
PADDLE_ENFORCE_EQ(
shape.size(),
axis.size(),
platform::errors::InvalidArgument(
"Ranks of shape_%s and axis_%s attributes of MatMulOp "
"must be equal.",
input_name,
input_name));
int num_negative = std::count(shape.begin(), shape.end(), -1);
PADDLE_ENFORCE_LE(num_negative,
1,
platform::errors::InvalidArgument(
"The max number of -1 in fused_reshape_%s is 1 "
"but received %d.",
input_name,
num_negative));
auto it_zero = std::find(shape.begin(), shape.end(), 0);
if (it_zero != shape.end()) {
for (uint64_t i = 0; i < shape.size(); i++) {
if (shape[i] == 0) {
PADDLE_ENFORCE_LT(i,
dim.size(),
platform::errors::InvalidArgument(
"The index of 0 in fused_reshape_%s ",
"should be less than output dim size, ",
"but the index is %d and output dim size is %d",
input_name,
i,
dim.size()));
shape[i] = dim.at(i);
}
}
}
dim = dim.reshape(shape).transpose(axis);
}
return dim;
......@@ -732,14 +681,11 @@ class MatMulOp : public framework::OperatorWithKernel {
framework::DDim ddim_out = phi::make_ddim(dim_out);
#ifdef PADDLE_WITH_MKLDNN
// if mkldnn matmul+transpose+reshape fuse activated
auto reshape_out =
context->Attrs().Get<std::vector<int>>("fused_reshape_Out");
auto transpose_out =
context->Attrs().Get<std::vector<int>>("fused_transpose_Out");
if (!reshape_out.empty() && !transpose_out.empty()) {
ddim_out = ddim_out.transpose(transpose_out).reshape(reshape_out);
auto shape = context->Attrs().Get<std::vector<int>>("fused_reshape_Out");
auto axis = context->Attrs().Get<std::vector<int>>("fused_transpose_Out");
if (!shape.empty() && !axis.empty()) {
ddim_out = ddim_out.transpose(axis).reshape(shape);
}
#endif
context->SetOutputDim("Out", ddim_out);
......
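One detail worth making explicit with a standalone sketch (the dims are hypothetical, not Paddle code): the two fuses apply the folded operations in opposite order. The output fuse mirrors matmul -> transpose -> reshape, so InferShape above does transpose(axis) first and reshape(shape) second, while the input fuse mirrors reshape -> transpose -> matmul, so GetDimForInput does reshape first and transpose second.

#include <cassert>
#include <cstdint>
#include <vector>

using Dims = std::vector<int64_t>;

// out[i] = in[axis[i]], the same convention DDim::transpose uses.
Dims Transpose(const Dims& in, const std::vector<int>& axis) {
  Dims out(in.size());
  for (size_t i = 0; i < axis.size(); ++i) out[i] = in[axis[i]];
  return out;
}

int main() {
  const std::vector<int> axis = {0, 2, 1, 3};
  // Output side: matmul result -> transpose -> reshape (merge the last two dims).
  Dims out = Transpose({2, 12, 128, 64}, axis);    // {2, 128, 12, 64}
  out = {out[0], out[1], out[2] * out[3]};         // {2, 128, 768}
  // Input side: tensor -> reshape (split the last dim) -> transpose -> matmul.
  Dims in = Transpose({2, 128, 12, 64}, axis);     // {2, 12, 128, 64}
  assert(out == Dims({2, 128, 768}));
  assert(in == Dims({2, 12, 128, 64}));
  return 0;
}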
......@@ -38,58 +38,7 @@ static framework::DDim GetDimForInput(const framework::InferShapeContext& ctx,
"shape of Input(%s) = [%s].",
dim));
// if mkldnn reshape+transpose+matmul fuse activated
if (!shape.empty() && !axis.empty()) {
PADDLE_ENFORCE_GE(
shape.size(),
2,
platform::errors::InvalidArgument(
"shape_%s attribute of MatMulOp was implemented for 2, 3 "
"or 4 dimensions.",
input_name));
PADDLE_ENFORCE_LE(
shape.size(),
4,
platform::errors::InvalidArgument(
"shape_%s attribute of MatMulOp was implemented for 2, 3 "
"or 4 dimensions.",
input_name));
PADDLE_ENFORCE_EQ(
shape.size(),
axis.size(),
platform::errors::InvalidArgument(
"Ranks of shape_%s and axis_%s attributes of MatMulOp "
"must be equal.",
input_name,
input_name));
int num_negative = std::count(shape.begin(), shape.end(), -1);
PADDLE_ENFORCE_LE(num_negative,
1,
platform::errors::InvalidArgument(
"The max number of -1 in fused_reshape_%s is 1 "
"but received %d.",
input_name,
num_negative));
auto it_zero = std::find(shape.begin(), shape.end(), 0);
if (it_zero != shape.end()) {
for (uint64_t i = 0; i < shape.size(); i++) {
if (shape[i] == 0) {
PADDLE_ENFORCE_LT(i,
dim.size(),
platform::errors::InvalidArgument(
"The index of 0 in fused_reshape_%s ",
"should be less than output dim size, ",
"but the index is %d and output dim size is %d",
input_name,
i,
dim.size()));
shape[i] = dim.at(i);
}
}
}
dim = dim.reshape(shape).transpose(axis);
}
return dim;
......@@ -169,13 +118,11 @@ class MatMulV2Op : public framework::OperatorWithKernel {
auto ddim_out = phi::make_ddim(new_dims);
#ifdef PADDLE_WITH_MKLDNN
// if mkldnn matmul_v2+transpose+reshape fuse activated
auto reshape_out = ctx->Attrs().Get<std::vector<int>>("fused_reshape_Out");
auto transpose_out =
ctx->Attrs().Get<std::vector<int>>("fused_transpose_Out");
auto shape = ctx->Attrs().Get<std::vector<int>>("fused_reshape_Out");
auto axis = ctx->Attrs().Get<std::vector<int>>("fused_transpose_Out");
if (!reshape_out.empty() && !transpose_out.empty()) {
ddim_out = ddim_out.transpose(transpose_out).reshape(reshape_out);
if (!shape.empty() && !axis.empty()) {
ddim_out = ddim_out.transpose(axis).reshape(shape);
}
#endif
......
......@@ -108,24 +108,6 @@ phi::DDim GetDimForInput(const ExecutionContext &ctx, std::string input_name) {
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
auto input_dims = ctx.Input<Tensor>(input_name)->dims();
if (!shape.empty() && !axis.empty()) {
auto it_zero = std::find(shape.begin(), shape.end(), 0);
if (it_zero != shape.end()) {
for (uint64_t i = 0; i < shape.size(); i++) {
if (shape[i] == 0) {
PADDLE_ENFORCE_LT(i,
input_dims.size(),
paddle::platform::errors::InvalidArgument(
"The index of 0 in fused_reshape_%s ",
"should be less than output dim size, ",
"but the index is %d and output dim size is %d",
input_name,
i,
input_dims.size()));
shape[i] = input_dims.at(i);
}
}
}
return input_dims.reshape(shape).transpose(axis);
}
return input_dims;
......@@ -225,12 +207,7 @@ class MatMulMKLDNNHandler
src_memory_p->set_data_handle(x_ptr);
weights_memory_p->set_data_handle(y_ptr);
dst_memory_p->set_data_handle(out_ptr);
matmul_p->execute(astream,
{
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p},
});
matmul_p->execute(astream, matmul_args);
x_ptr = static_cast<char *>(x_ptr) + std::get<0>(offsets);
y_ptr = static_cast<char *>(y_ptr) + std::get<1>(offsets);
out_ptr = static_cast<char *>(out_ptr) + std::get<2>(offsets);
......@@ -270,25 +247,6 @@ class MatMulMKLDNNHandler
auto input_dims = ctx.Input<Tensor>(input_name)->dims();
auto new_dims = input_dims;
if (!shape.empty() && !axis.empty()) {
auto it_zero = std::find(shape.begin(), shape.end(), 0);
if (it_zero != shape.end()) {
for (uint64_t i = 0; i < shape.size(); i++) {
if (shape[i] == 0) {
PADDLE_ENFORCE_LT(
i,
input_dims.size(),
paddle::platform::errors::InvalidArgument(
"The index of 0 in fused_reshape_%s ",
"should be less than output dim size, ",
"but the index is %d and output dim size is %d",
input_name,
i,
input_dims.size()));
shape[i] = input_dims.at(i);
}
}
}
new_dims = input_dims.reshape(shape).transpose(axis);
}
......@@ -519,7 +477,10 @@ static void ExecuteMatMul(const ExecutionContext &ctx) {
constexpr bool is_int8 = IsInt8<XT>();
constexpr bool is_bfloat16 = IsBfloat16<XT>();
const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
constexpr bool fuse_relu = false; // TODO(intel): Enable eltwise fuses
const bool fuse_relu =
ctx.HasAttr("fuse_activation")
? ctx.Attr<std::string>("fuse_activation") == "relu"
: false;
auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Input<Tensor>("Y");
auto *out = ctx.Output<Tensor>("Out");
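The kernel now reads relu fusion from a fuse_activation string attribute that an activation fuse pass is expected to set, falling back to no fusion when the attribute is absent. A standalone sketch of that lookup (the attribute map here is an illustrative stand-in, not the real ExecutionContext API):

#include <cassert>
#include <map>
#include <string>

// Stand-in for the HasAttr/Attr pattern on an op's attribute map.
bool FuseRelu(const std::map<std::string, std::string>& attrs) {
  auto it = attrs.find("fuse_activation");
  return it != attrs.end() && it->second == "relu";  // missing attribute -> no fusion
}

int main() {
  assert(FuseRelu({{"fuse_activation", "relu"}}));   // set by an activation fuse pass
  assert(!FuseRelu({{"fuse_activation", "gelu"}}));  // other activations are not relu-fused here
  assert(!FuseRelu({}));                             // attribute not present at all
  return 0;
}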
......@@ -596,23 +557,6 @@ std::vector<int64_t> GetInputStrides(const ExecutionContext &ctx,
auto input_dims = ctx.Input<Tensor>(input_name)->dims();
auto new_dims = input_dims;
if (!shape.empty() && !axis.empty()) {
auto it_zero = std::find(shape.begin(), shape.end(), 0);
if (it_zero != shape.end()) {
for (uint64_t i = 0; i < shape.size(); i++) {
if (shape[i] == 0) {
PADDLE_ENFORCE_LT(i,
input_dims.size(),
paddle::platform::errors::InvalidArgument(
"The index of 0 in fused_reshape_%s ",
"should be less than output dim size, ",
"but the index is %d and output dim size is %d",
input_name,
i,
input_dims.size()));
shape[i] = input_dims.at(i);
}
}
}
new_dims = input_dims.reshape(shape).transpose(axis);
}
......
......@@ -172,10 +172,13 @@ DDim stride_numel(const DDim& ddim) {
}
DDim DDim::reshape(std::vector<int>& shape) const {
const int64_t copy_dim_val = 0;
const DDim& in_dims = *this;
DDim out_dims;
out_dims.rank_ = shape.size();
for (uint64_t i = 0; i < shape.size(); ++i) {
if (shape[i] == 0) {
shape[i] = in_dims.at(i);
}
}
// Dim marked as "-1" must be inferred
auto it = std::find(shape.begin(), shape.end(), -1);
......@@ -186,54 +189,14 @@ DDim DDim::reshape(std::vector<int>& shape) const {
shape[index] = product(in_dims) / reshape_out_product;
}
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] == copy_dim_val) {
PADDLE_ENFORCE_LT(static_cast<int>(i),
in_dims.size(),
phi::errors::InvalidArgument(
"Index %d of shape under which the value of 0 "
"is stored, must be lower than the number of "
"old dimensions. But received shape[%d] = 0, "
"dimensions = %d, shape = [%s].",
i,
in_dims.size(),
in_dims));
out_dims[i] = in_dims[i];
} else {
out_dims[i] = shape[i];
}
}
return out_dims;
return phi::make_ddim(shape);
}
DDim DDim::transpose(const std::vector<int>& axis) const {
const DDim& in_dims = *this;
size_t in_rank = in_dims.size();
size_t axis_size = axis.size();
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(),
axis_size,
phi::errors::InvalidArgument(
"In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(
in_rank,
axis_size,
phi::errors::InvalidArgument("The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank,
axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
axis_size,
phi::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
DDim out_dims(in_dims);
for (size_t i = 0; i < axis_size; i++) {
for (size_t i = 0; i < axis.size(); i++) {
out_dims[i] = in_dims[axis[i]];
}
return out_dims;
......
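The zero-copy and minus-one-inference rules that each matmul caller previously re-implemented now live in DDim::reshape itself. A standalone sketch of those rules on a plain vector (hypothetical dims, not the phi implementation):

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// shape[i] == 0 copies in[i]; a single -1 is inferred so the element count is preserved.
std::vector<int64_t> Reshape(const std::vector<int64_t>& in,
                             std::vector<int64_t> shape) {
  for (size_t i = 0; i < shape.size(); ++i)
    if (shape[i] == 0) shape[i] = in[i];  // 0 means "keep the old dimension"
  const int64_t in_numel = std::accumulate(
      in.begin(), in.end(), int64_t{1}, std::multiplies<int64_t>());
  int64_t known = 1;
  int infer = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == -1) infer = static_cast<int>(i);
    else known *= shape[i];
  }
  if (infer >= 0) shape[infer] = in_numel / known;  // infer the single -1 entry
  return shape;
}

int main() {
  // [2, 128, 768] reshaped with {0, 0, 12, -1}: keep 2 and 128, infer 64.
  assert(Reshape({2, 128, 768}, {0, 0, 12, -1}) ==
         std::vector<int64_t>({2, 128, 12, 64}));
  return 0;
}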
......@@ -31,9 +31,9 @@ class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
input_dim = draw(st.sampled_from([32]))
activation_type = draw(
st.sampled_from([
'relu', 'gelu', 'tanh', 'sigmoid', 'swish', 'mish', 'sqrt',
'hard_swish', 'sigmoid', 'abs', 'relu6', 'clip', 'tanh',
'hard_sigmoid', 'leaky_relu'
'relu', 'gelu', 'swish', 'mish', 'sqrt', 'hard_swish',
'sigmoid', 'abs', 'relu6', 'clip', 'tanh', 'hard_sigmoid',
'leaky_relu'
]))
def generate_input(type):
......
......@@ -30,9 +30,9 @@ class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest):
input_dim = draw(st.sampled_from([16, 32, 64]))
activation_type = draw(
st.sampled_from([
'relu', 'gelu', 'tanh', 'sigmoid', 'swish', 'mish', 'sqrt',
'hard_swish', 'sigmoid', 'abs', 'relu6', 'clip', 'tanh',
'hard_sigmoid', 'leaky_relu'
'relu', 'gelu', 'swish', 'mish', 'sqrt', 'hard_swish',
'sigmoid', 'abs', 'relu6', 'clip', 'tanh', 'hard_sigmoid',
'leaky_relu'
]))
def generate_input(type):
......