Unverified commit 570d0322, authored by Sławomir Siwek, committed by GitHub

matmul and matmul_v2 refactor (#42732)

* matmul refactor

* remove UTs that only check ENFORCE output

* code format

* improve memory usage
Parent 6f0a28f5
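The core of the refactor: the ad-hoc "-1" inference that was duplicated at every reshape call site (matmul_op.cc, matmul_v2_op.cc and the oneDNN matmul handler, see the hunks below) now lives inside DDim::reshape itself, so the call sites reduce to dim.reshape(shape).transpose(axis). The following is a minimal standalone sketch of that inference step, using plain std::vector in place of framework::DDim; the function name and shapes are illustrative, not part of the Paddle API, and it assumes exactly one -1 and no 0 ("copy") entries in the target shape.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Infer a single "-1" entry in `shape` from the element count of `dims`,
// mirroring the logic this commit consolidates into DDim::reshape.
// Assumes at most one -1 and no 0 entries in `shape`.
std::vector<int> InferNegativeDim(const std::vector<int64_t>& dims,
                                  std::vector<int> shape) {
  auto it = std::find(shape.begin(), shape.end(), -1);
  if (it != shape.end()) {
    const int64_t dim_product = std::accumulate(
        dims.begin(), dims.end(), int64_t{1}, std::multiplies<int64_t>());
    // Seeding the accumulation with -1 cancels the -1 inside `shape`,
    // leaving the positive product of the known entries.
    const int64_t shape_product = std::accumulate(
        shape.begin(), shape.end(), int64_t{-1}, std::multiplies<int64_t>());
    *it = static_cast<int>(dim_product / shape_product);
  }
  return shape;
}

int main() {
  // 2*128*12*64 elements reshaped to {2, 128, -1}: the -1 becomes 768.
  assert(InferNegativeDim({2, 128, 12, 64}, {2, 128, -1})[2] == 768);
  return 0;
}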
......@@ -378,20 +378,6 @@ framework::DDim GetDimForInput(const framework::InferShapeContext &ctx,
}
}
// if "-1" is present then one of reshape dims must be infered
auto it_negative = std::find(shape.begin(), shape.end(), -1);
if (it_negative != shape.end()) {
int64_t dim_product = 1;
for (int i = 0; i < dim.size(); i++) {
dim_product *= dim.at(i);
}
int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1,
std::multiplies<int>());
int index = std::distance(shape.begin(), it_negative);
shape[index] = dim_product / shape_product;
}
dim = dim.reshape(shape).transpose(axis);
}
return dim;
......@@ -686,76 +672,11 @@ class MatMulOp : public framework::OperatorWithKernel {
context->Attrs().Get<std::vector<int>>("fused_transpose_Out");
if (!reshape_out.empty() && !transpose_out.empty()) {
auto reshape_out_size = reshape_out.size();
auto transpose_out_size = transpose_out.size();
PADDLE_ENFORCE_EQ(transpose_out_size, 4,
platform::errors::InvalidArgument(
"transpose_out supported rank is 4, "
"received %d",
transpose_out_size));
const std::vector<int> supported_axis{0, 2, 1, 3};
const bool supported_transpose_axis = std::equal(
transpose_out.begin(), transpose_out.end(), supported_axis.begin());
PADDLE_ENFORCE_EQ(
supported_transpose_axis, true,
platform::errors::InvalidArgument(
"supported transpose axis for the fuse are {0, 2, 1, 3}"));
PADDLE_ENFORCE_EQ(
reshape_out_size, 3,
platform::errors::InvalidArgument("reshape_out supported rank is 3, "
"received %d",
reshape_out_size));
// int num_negative = std::count(reshape_out.begin(), reshape_out.end(),
// -1);
// PADDLE_ENFORCE_LE(num_negative, 1,
// platform::errors::InvalidArgument(
// "The max number of -1 in fused_reshape_Out is 1 "
// "but received %d.",
// num_negative));
// auto it_zero = std::find(reshape_out.begin(), reshape_out.end(), 0);
// if (it_zero != reshape_out.end()) {
// for (uint64_t i = 0; i < reshape_out.size(); i++) {
// if (reshape_out[i] == 0) {
// PADDLE_ENFORCE_LT(
// i, ddim_out.size(),
// platform::errors::InvalidArgument(
// "The index of 0 in fused_reshape_Out ",
// "should be less than output dim size, ",
// "but the index is %d and output dim size is %d", i,
// ddim_out.size()));
// reshape_out[i] = ddim_out.at(i);
// }
// }
// }
// if "-1" is present then one of reshape dims must be infered
auto it = std::find(reshape_out.begin(), reshape_out.end(), -1);
if (it != reshape_out.end()) {
int index = std::distance(reshape_out.begin(), it);
auto ddim_out_vec = phi::vectorize(ddim_out);
int ddim_out_product =
std::accumulate(ddim_out_vec.begin(), ddim_out_vec.end(), 1,
std::multiplies<int>());
int reshape_out_product = std::accumulate(
reshape_out.begin(), reshape_out.end(), -1, std::multiplies<int>());
reshape_out[index] = ddim_out_product / reshape_out_product;
}
framework::DDim shape_out =
ddim_out.transpose(transpose_out).reshape(reshape_out);
context->SetOutputDim("Out", shape_out);
} else {
context->SetOutputDim("Out", ddim_out);
ddim_out = ddim_out.transpose(transpose_out).reshape(reshape_out);
}
#else
context->SetOutputDim("Out", ddim_out);
#endif
context->ShareLoD("X", /*->*/ "Out");
context->SetOutputDim("Out", ddim_out);
context->ShareLoD("X", "Out");
}
framework::OpKernelType GetExpectedKernelType(
......
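With the rank/axis enforcement and the "-1" handling dropped from MatMulOp::InferShape, the fused-output branch above reduces to a single expression, ddim_out.transpose(transpose_out).reshape(reshape_out), followed by the common SetOutputDim/ShareLoD calls. A hedged worked example with hypothetical shapes: for ddim_out = [2, 12, 128, 64], fused_transpose_Out = {0, 2, 1, 3} and fused_reshape_Out = {0, 0, 768}, the transpose yields [2, 128, 12, 64] and the reshape (where 0 means "copy the dimension at this index") yields [2, 128, 768], which becomes the inferred dim of "Out".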
......@@ -79,20 +79,6 @@ static framework::DDim GetDimForInput(const framework::InferShapeContext& ctx,
}
}
// if "-1" is present then one of reshape dims must be infered
auto it_negative = std::find(shape.begin(), shape.end(), -1);
if (it_negative != shape.end()) {
int64_t dim_product = 1;
for (int i = 0; i < dim.size(); i++) {
dim_product *= dim.at(i);
}
int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1,
std::multiplies<int>());
int index = std::distance(shape.begin(), it_negative);
shape[index] = dim_product / shape_product;
}
dim = dim.reshape(shape).transpose(axis);
}
return dim;
......@@ -176,77 +162,12 @@ class MatMulV2Op : public framework::OperatorWithKernel {
ctx->Attrs().Get<std::vector<int>>("fused_transpose_Out");
if (!reshape_out.empty() && !transpose_out.empty()) {
auto reshape_out_size = reshape_out.size();
auto transpose_out_size = transpose_out.size();
PADDLE_ENFORCE_EQ(transpose_out_size, 4,
platform::errors::InvalidArgument(
"transpose_out supported rank is 4, "
"received %d",
transpose_out_size));
const std::vector<int> supported_axis{0, 2, 1, 3};
const bool supported_transpose_axis = std::equal(
transpose_out.begin(), transpose_out.end(), supported_axis.begin());
PADDLE_ENFORCE_EQ(
supported_transpose_axis, true,
platform::errors::InvalidArgument(
"supported transpose axis for the fuse are {0, 2, 1, 3}"));
PADDLE_ENFORCE_EQ(
reshape_out_size, 3,
platform::errors::InvalidArgument("reshape_out supported rank is 3, "
"received %d",
reshape_out_size));
// int num_negative = std::count(reshape_out.begin(), reshape_out.end(),
// -1);
// PADDLE_ENFORCE_LE(num_negative, 1,
// platform::errors::InvalidArgument(
// "The max number of -1 in fused_reshape_Out is 1 "
// "but received %d.",
// num_negative));
// auto it_zero = std::find(reshape_out.begin(), reshape_out.end(), 0);
// if (it_zero != reshape_out.end()) {
// for (uint64_t i = 0; i < reshape_out.size(); i++) {
// if (reshape_out[i] == 0) {
// PADDLE_ENFORCE_LT(
// i, ddim_out.size(),
// platform::errors::InvalidArgument(
// "The index of 0 in fused_reshape_Out ",
// "should be less than output dim size, ",
// "but the index is %d and output dim size is %d", i,
// ddim_out.size()));
// reshape_out[i] = ddim_out.at(i);
// }
// }
// }
// if "-1" is present then one of reshape dims must be infered
auto it = std::find(reshape_out.begin(), reshape_out.end(), -1);
if (it != reshape_out.end()) {
int index = std::distance(reshape_out.begin(), it);
auto ddim_out_vec = phi::vectorize(ddim_out);
int ddim_out_product =
std::accumulate(ddim_out_vec.begin(), ddim_out_vec.end(), 1,
std::multiplies<int>());
int reshape_out_product = std::accumulate(
reshape_out.begin(), reshape_out.end(), -1, std::multiplies<int>());
reshape_out[index] = ddim_out_product / reshape_out_product;
}
framework::DDim shape_out =
ddim_out.transpose(transpose_out).reshape(reshape_out);
ctx->SetOutputDim("Out", shape_out);
} else {
ctx->SetOutputDim("Out", ddim_out);
ddim_out = ddim_out.transpose(transpose_out).reshape(reshape_out);
}
#else
ctx->SetOutputDim("Out", ddim_out);
#endif
ctx->ShareLoD("X", /* --> */ "Out");
ctx->SetOutputDim("Out", ddim_out);
ctx->ShareLoD("X", "Out");
}
protected:
......
......@@ -257,20 +257,6 @@ class MatMulMKLDNNHandler
}
}
// if "-1" is present then one of reshape dims must be infered
auto it_negative = std::find(shape.begin(), shape.end(), -1);
if (it_negative != shape.end()) {
int64_t dim_product = 1;
for (int i = 0; i < input_dims.size(); i++) {
dim_product *= input_dims.at(i);
}
int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1,
std::multiplies<int>());
int index = std::distance(shape.begin(), it_negative);
shape[index] = dim_product / shape_product;
}
return input_dims.reshape(shape).transpose(axis);
}
return input_dims;
......@@ -299,20 +285,6 @@ class MatMulMKLDNNHandler
}
}
// if "-1" is present then one of reshape dims must be infered
auto it_negative = std::find(shape.begin(), shape.end(), -1);
if (it_negative != shape.end()) {
int64_t dim_product = 1;
for (int i = 0; i < input_dims.size(); i++) {
dim_product *= input_dims.at(i);
}
int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1,
std::multiplies<int>());
int index = std::distance(shape.begin(), it_negative);
shape[index] = dim_product / shape_product;
}
new_dims = input_dims.reshape(shape).transpose(axis);
}
......
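Note that the oneDNN handler applies the fused input attributes in the opposite order to the output path: input_dims.reshape(shape).transpose(axis), i.e. reshape first, then transpose. With hypothetical attention-style shapes, input_dims = [2, 128, 768] with a fused reshape of {0, 0, 12, 64} and a fused transpose axis of {0, 2, 1, 3} becomes [2, 128, 12, 64] after the reshape and [2, 12, 128, 64] after the transpose; the "-1" inference removed here is likewise handled inside DDim::reshape now.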
......@@ -171,11 +171,21 @@ DDim stride_numel(const DDim& ddim) {
return strides;
}
DDim DDim::reshape(const std::vector<int>& shape) const {
DDim DDim::reshape(std::vector<int>& shape) const {
const int64_t copy_dim_val = 0;
const DDim& in_dims = *this;
DDim out_dims;
out_dims.rank_ = shape.size();
// Dim marked as "-1" must be inferred
auto it = std::find(shape.begin(), shape.end(), -1);
if (it != shape.end()) {
int index = std::distance(shape.begin(), it);
int reshape_out_product =
std::accumulate(shape.begin(), shape.end(), -1, std::multiplies<int>());
shape[index] = product(in_dims) / reshape_out_product;
}
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] == copy_dim_val) {
PADDLE_ENFORCE_LT(static_cast<int>(i),
......
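For reference, a hedged usage note on the new DDim::reshape (the numbers are illustrative): the std::accumulate call is seeded with -1 so that the -1 placeholder in shape cancels out. With in_dims = [2, 128, 12, 64] and shape = {2, 128, -1}, reshape_out_product = (-1)*2*128*(-1) = 256 and product(in_dims)/256 = 768, so the -1 slot becomes 768. The declaration also changes from const std::vector<int>& to std::vector<int>& because shape is now modified in place, making the inferred value visible to the caller.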
......@@ -155,7 +155,7 @@ class DDim {
std::string to_str() const;
DDim reshape(const std::vector<int>& shape) const;
DDim reshape(std::vector<int>& shape) const;
DDim transpose(const std::vector<int>& axis) const;
......
......@@ -519,43 +519,6 @@ class TestMatMulOpTransposeReshapeOtherDimInt(
self.data_type_ = np.int8
class TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException(
TestMatMulOpTransposeReshapeBasicFloat):
def init_params_and_out(self):
self.transpose_out = [0, 1, 2, 3]
self.reshape_out = [0, 0, self.x.shape[1] * self.y.shape[-1]]
self.out = np.matmul(self.x, self.y)
def test_check_output(self):
self.assertRaises(AttributeError, self.check_raise_error,
'supported transpose axis '
'for the fuse are {0, 2, 1, 3}')
class TestMatMulOpTransposeReshapeTransposeRankNotSupportedException(
TestMatMulOpTransposeReshapeBasicFloat):
def init_params_and_out(self):
self.transpose_out = [0, 2, 1]
self.reshape_out = [0, 0, self.x.shape[1] * self.y.shape[-1]]
self.out = np.matmul(self.x, self.y)
def test_check_output(self):
self.assertRaises(AttributeError, self.check_raise_error,
'transpose_out supported rank is 4')
class TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException(
TestMatMulOpTransposeReshapeBasicFloat):
def init_params_and_out(self):
self.transpose_out = [0, 2, 1, 3]
self.reshape_out = [0, 0]
self.out = np.matmul(self.x, self.y)
def test_check_output(self):
self.assertRaises(AttributeError, self.check_raise_error,
'reshape_out supported rank is 3')
if __name__ == "__main__":
from paddle import enable_static
enable_static()
......
......@@ -26,14 +26,11 @@ import paddle.fluid.framework as framework
from paddle.fluid.tests.unittests.mkldnn.test_matmul_mkldnn_op import (
TestMatMulOpTransposeReshapeEmptyFloat,
TestMatMulOpTransposeReshapeBasicFloat,
TestMatMulOpTransposeReshapeOtherDimFloat,
TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException,
TestMatMulOpTransposeReshapeTransposeRankNotSupportedException,
TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException,
TestReshapeTransposeMatMulOp, TestReshapeTransposeMatMulOp4DXFloat,
TestReshapeTransposeMatMulOp4DYFloat, TestReshapeTransposeMatMulOp4DXYFloat,
TestReshapeTransposeMatMulOp2DXFloat, TestReshapeTransposeMatMulOp2DYFloat,
TestReshapeTransposeMatMulOp3DXFloat, TestReshapeTransposeMatMulOp3DYFloat)
TestMatMulOpTransposeReshapeOtherDimFloat, TestReshapeTransposeMatMulOp,
TestReshapeTransposeMatMulOp4DXFloat, TestReshapeTransposeMatMulOp4DYFloat,
TestReshapeTransposeMatMulOp4DXYFloat, TestReshapeTransposeMatMulOp2DXFloat,
TestReshapeTransposeMatMulOp2DYFloat, TestReshapeTransposeMatMulOp3DXFloat,
TestReshapeTransposeMatMulOp3DYFloat)
def reference_matmul(X, Y, transpose_x=False, transpose_y=False):
......@@ -457,24 +454,6 @@ class TestMatMulV2OpTransposeReshapeOtherDimFloat(
self.op_type = "matmul_v2"
class TestMatMulV2OpTransposeReshapeTransposeAxisNotSupportedException(
TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException):
def set_op_type(self):
self.op_type = "matmul_v2"
class TestMatMulV2OpTransposeReshapeRankOfReshapeNotSupportedException(
TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException):
def set_op_type(self):
self.op_type = "matmul_v2"
class TestMatMulV2OpTransposeReshapeTransposeRankNotSupportedException(
TestMatMulOpTransposeReshapeTransposeRankNotSupportedException):
def set_op_type(self):
self.op_type = "matmul_v2"
class TestMatMulV2OpReshapeTranspose(TestReshapeTransposeMatMulOp):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
......