diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index 2540af5d472c4916514cbfa8257af487b6de6b5f..9d381e1f22b5f9233f7d8e919f6680b28870ba94 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -378,20 +378,6 @@ framework::DDim GetDimForInput(const framework::InferShapeContext &ctx, } } - // if "-1" is present then one of reshape dims must be infered - auto it_negative = std::find(shape.begin(), shape.end(), -1); - if (it_negative != shape.end()) { - int64_t dim_product = 1; - for (int i = 0; i < dim.size(); i++) { - dim_product *= dim.at(i); - } - - int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1, - std::multiplies()); - int index = std::distance(shape.begin(), it_negative); - shape[index] = dim_product / shape_product; - } - dim = dim.reshape(shape).transpose(axis); } return dim; @@ -686,76 +672,11 @@ class MatMulOp : public framework::OperatorWithKernel { context->Attrs().Get>("fused_transpose_Out"); if (!reshape_out.empty() && !transpose_out.empty()) { - auto reshape_out_size = reshape_out.size(); - auto transpose_out_size = transpose_out.size(); - PADDLE_ENFORCE_EQ(transpose_out_size, 4, - platform::errors::InvalidArgument( - "transpose_out supported rank is 4, " - "received %d", - transpose_out_size)); - const std::vector supported_axis{0, 2, 1, 3}; - const bool supported_transpose_axis = std::equal( - transpose_out.begin(), transpose_out.end(), supported_axis.begin()); - PADDLE_ENFORCE_EQ( - supported_transpose_axis, true, - platform::errors::InvalidArgument( - "supported transpose axis for the fuse are {0, 2, 1, 3}")); - PADDLE_ENFORCE_EQ( - reshape_out_size, 3, - platform::errors::InvalidArgument("reshape_out supported rank is 3, " - "received %d", - reshape_out_size)); - - // int num_negative = std::count(reshape_out.begin(), reshape_out.end(), - // -1); - // PADDLE_ENFORCE_LE(num_negative, 1, - // platform::errors::InvalidArgument( - // "The max number of -1 in fused_reshape_Out is 1 " - // "but received %d.", - // num_negative)); - - // auto it_zero = std::find(reshape_out.begin(), reshape_out.end(), 0); - // if (it_zero != reshape_out.end()) { - // for (uint64_t i = 0; i < reshape_out.size(); i++) { - // if (reshape_out[i] == 0) { - // PADDLE_ENFORCE_LT( - // i, ddim_out.size(), - // platform::errors::InvalidArgument( - // "The index of 0 in fused_reshape_Out ", - // "should be less than output dim size, ", - // "but the index is %d and output dim size is %d", i, - // ddim_out.size())); - // reshape_out[i] = ddim_out.at(i); - // } - // } - // } - - // if "-1" is present then one of reshape dims must be infered - auto it = std::find(reshape_out.begin(), reshape_out.end(), -1); - if (it != reshape_out.end()) { - int index = std::distance(reshape_out.begin(), it); - - auto ddim_out_vec = phi::vectorize(ddim_out); - - int ddim_out_product = - std::accumulate(ddim_out_vec.begin(), ddim_out_vec.end(), 1, - std::multiplies()); - int reshape_out_product = std::accumulate( - reshape_out.begin(), reshape_out.end(), -1, std::multiplies()); - - reshape_out[index] = ddim_out_product / reshape_out_product; - } - - framework::DDim shape_out = - ddim_out.transpose(transpose_out).reshape(reshape_out); - context->SetOutputDim("Out", shape_out); - } else { - context->SetOutputDim("Out", ddim_out); + ddim_out = ddim_out.transpose(transpose_out).reshape(reshape_out); } -#else - context->SetOutputDim("Out", ddim_out); #endif - context->ShareLoD("X", /*->*/ "Out"); + context->SetOutputDim("Out", ddim_out); + context->ShareLoD("X", "Out"); } framework::OpKernelType GetExpectedKernelType( diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index 55294331a9c85e980dfb4bb5d5fbfe650ef9b06c..162ebdafec1cb89c1531e03d78ecb383519d5357 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -79,20 +79,6 @@ static framework::DDim GetDimForInput(const framework::InferShapeContext& ctx, } } - // if "-1" is present then one of reshape dims must be infered - auto it_negative = std::find(shape.begin(), shape.end(), -1); - if (it_negative != shape.end()) { - int64_t dim_product = 1; - for (int i = 0; i < dim.size(); i++) { - dim_product *= dim.at(i); - } - - int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1, - std::multiplies()); - int index = std::distance(shape.begin(), it_negative); - shape[index] = dim_product / shape_product; - } - dim = dim.reshape(shape).transpose(axis); } return dim; @@ -176,77 +162,12 @@ class MatMulV2Op : public framework::OperatorWithKernel { ctx->Attrs().Get>("fused_transpose_Out"); if (!reshape_out.empty() && !transpose_out.empty()) { - auto reshape_out_size = reshape_out.size(); - auto transpose_out_size = transpose_out.size(); - PADDLE_ENFORCE_EQ(transpose_out_size, 4, - platform::errors::InvalidArgument( - "transpose_out supported rank is 4, " - "received %d", - transpose_out_size)); - const std::vector supported_axis{0, 2, 1, 3}; - const bool supported_transpose_axis = std::equal( - transpose_out.begin(), transpose_out.end(), supported_axis.begin()); - PADDLE_ENFORCE_EQ( - supported_transpose_axis, true, - platform::errors::InvalidArgument( - "supported transpose axis for the fuse are {0, 2, 1, 3}")); - PADDLE_ENFORCE_EQ( - reshape_out_size, 3, - platform::errors::InvalidArgument("reshape_out supported rank is 3, " - "received %d", - reshape_out_size)); - - // int num_negative = std::count(reshape_out.begin(), reshape_out.end(), - // -1); - // PADDLE_ENFORCE_LE(num_negative, 1, - // platform::errors::InvalidArgument( - // "The max number of -1 in fused_reshape_Out is 1 " - // "but received %d.", - // num_negative)); - - // auto it_zero = std::find(reshape_out.begin(), reshape_out.end(), 0); - // if (it_zero != reshape_out.end()) { - // for (uint64_t i = 0; i < reshape_out.size(); i++) { - // if (reshape_out[i] == 0) { - // PADDLE_ENFORCE_LT( - // i, ddim_out.size(), - // platform::errors::InvalidArgument( - // "The index of 0 in fused_reshape_Out ", - // "should be less than output dim size, ", - // "but the index is %d and output dim size is %d", i, - // ddim_out.size())); - // reshape_out[i] = ddim_out.at(i); - // } - // } - // } - - // if "-1" is present then one of reshape dims must be infered - auto it = std::find(reshape_out.begin(), reshape_out.end(), -1); - if (it != reshape_out.end()) { - int index = std::distance(reshape_out.begin(), it); - - auto ddim_out_vec = phi::vectorize(ddim_out); - - int ddim_out_product = - std::accumulate(ddim_out_vec.begin(), ddim_out_vec.end(), 1, - std::multiplies()); - int reshape_out_product = std::accumulate( - reshape_out.begin(), reshape_out.end(), -1, std::multiplies()); - - reshape_out[index] = ddim_out_product / reshape_out_product; - } - - framework::DDim shape_out = - ddim_out.transpose(transpose_out).reshape(reshape_out); - ctx->SetOutputDim("Out", shape_out); - } else { - ctx->SetOutputDim("Out", ddim_out); + ddim_out = ddim_out.transpose(transpose_out).reshape(reshape_out); } -#else - ctx->SetOutputDim("Out", ddim_out); #endif - ctx->ShareLoD("X", /* --> */ "Out"); + ctx->SetOutputDim("Out", ddim_out); + ctx->ShareLoD("X", "Out"); } protected: diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index f4137733e300e9650fe9948d8f20296aed2e9857..e9abe84e67980377c254fdcadae6a6e764acb869 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -257,20 +257,6 @@ class MatMulMKLDNNHandler } } - // if "-1" is present then one of reshape dims must be infered - auto it_negative = std::find(shape.begin(), shape.end(), -1); - if (it_negative != shape.end()) { - int64_t dim_product = 1; - for (int i = 0; i < input_dims.size(); i++) { - dim_product *= input_dims.at(i); - } - - int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1, - std::multiplies()); - int index = std::distance(shape.begin(), it_negative); - shape[index] = dim_product / shape_product; - } - return input_dims.reshape(shape).transpose(axis); } return input_dims; @@ -299,20 +285,6 @@ class MatMulMKLDNNHandler } } - // if "-1" is present then one of reshape dims must be infered - auto it_negative = std::find(shape.begin(), shape.end(), -1); - if (it_negative != shape.end()) { - int64_t dim_product = 1; - for (int i = 0; i < input_dims.size(); i++) { - dim_product *= input_dims.at(i); - } - - int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1, - std::multiplies()); - int index = std::distance(shape.begin(), it_negative); - shape[index] = dim_product / shape_product; - } - new_dims = input_dims.reshape(shape).transpose(axis); } diff --git a/paddle/phi/core/ddim.cc b/paddle/phi/core/ddim.cc index e6bf81590f158ca721b5421c62b49478c93a7024..1809c413bc146caf4c1a94d92d52a5fd5ccc96f5 100644 --- a/paddle/phi/core/ddim.cc +++ b/paddle/phi/core/ddim.cc @@ -171,11 +171,21 @@ DDim stride_numel(const DDim& ddim) { return strides; } -DDim DDim::reshape(const std::vector& shape) const { +DDim DDim::reshape(std::vector& shape) const { const int64_t copy_dim_val = 0; const DDim& in_dims = *this; DDim out_dims; out_dims.rank_ = shape.size(); + + // Dim marked as "-1" must be inferred + auto it = std::find(shape.begin(), shape.end(), -1); + if (it != shape.end()) { + int index = std::distance(shape.begin(), it); + int reshape_out_product = + std::accumulate(shape.begin(), shape.end(), -1, std::multiplies()); + shape[index] = product(in_dims) / reshape_out_product; + } + for (size_t i = 0; i < shape.size(); ++i) { if (shape[i] == copy_dim_val) { PADDLE_ENFORCE_LT(static_cast(i), diff --git a/paddle/phi/core/ddim.h b/paddle/phi/core/ddim.h index ce462d8d954023a1ccd2ff4d33e1cf9611b40513..dd13081ddafffab4557f03fd722d0d31021fb1db 100644 --- a/paddle/phi/core/ddim.h +++ b/paddle/phi/core/ddim.h @@ -155,7 +155,7 @@ class DDim { std::string to_str() const; - DDim reshape(const std::vector& shape) const; + DDim reshape(std::vector& shape) const; DDim transpose(const std::vector& axis) const; diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py index d13012ee33847eecc4306f4bfb01257406fe74c3..634288c3e875b6808cf3e6fdb3d78a97a7894412 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py @@ -519,43 +519,6 @@ class TestMatMulOpTransposeReshapeOtherDimInt( self.data_type_ = np.int8 -class TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException( - TestMatMulOpTransposeReshapeBasicFloat): - def init_params_and_out(self): - self.transpose_out = [0, 1, 2, 3] - self.reshape_out = [0, 0, self.x.shape[1] * self.y.shape[-1]] - self.out = np.matmul(self.x, self.y) - - def test_check_output(self): - self.assertRaises(AttributeError, self.check_raise_error, - 'supported transpose axis ' - 'for the fuse are {0, 2, 1, 3}') - - -class TestMatMulOpTransposeReshapeTransposeRankNotSupportedException( - TestMatMulOpTransposeReshapeBasicFloat): - def init_params_and_out(self): - self.transpose_out = [0, 2, 1] - self.reshape_out = [0, 0, self.x.shape[1] * self.y.shape[-1]] - self.out = np.matmul(self.x, self.y) - - def test_check_output(self): - self.assertRaises(AttributeError, self.check_raise_error, - 'transpose_out supported rank is 4') - - -class TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException( - TestMatMulOpTransposeReshapeBasicFloat): - def init_params_and_out(self): - self.transpose_out = [0, 2, 1, 3] - self.reshape_out = [0, 0] - self.out = np.matmul(self.x, self.y) - - def test_check_output(self): - self.assertRaises(AttributeError, self.check_raise_error, - 'reshape_out supported rank is 3') - - if __name__ == "__main__": from paddle import enable_static enable_static() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py index 4e59e41b60851d5ed32adf72d8095dedb203b0d8..69cee49c3ec618e04dc07e6ec0727bbfaec23ce9 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py @@ -26,14 +26,11 @@ import paddle.fluid.framework as framework from paddle.fluid.tests.unittests.mkldnn.test_matmul_mkldnn_op import ( TestMatMulOpTransposeReshapeEmptyFloat, TestMatMulOpTransposeReshapeBasicFloat, - TestMatMulOpTransposeReshapeOtherDimFloat, - TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException, - TestMatMulOpTransposeReshapeTransposeRankNotSupportedException, - TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException, - TestReshapeTransposeMatMulOp, TestReshapeTransposeMatMulOp4DXFloat, - TestReshapeTransposeMatMulOp4DYFloat, TestReshapeTransposeMatMulOp4DXYFloat, - TestReshapeTransposeMatMulOp2DXFloat, TestReshapeTransposeMatMulOp2DYFloat, - TestReshapeTransposeMatMulOp3DXFloat, TestReshapeTransposeMatMulOp3DYFloat) + TestMatMulOpTransposeReshapeOtherDimFloat, TestReshapeTransposeMatMulOp, + TestReshapeTransposeMatMulOp4DXFloat, TestReshapeTransposeMatMulOp4DYFloat, + TestReshapeTransposeMatMulOp4DXYFloat, TestReshapeTransposeMatMulOp2DXFloat, + TestReshapeTransposeMatMulOp2DYFloat, TestReshapeTransposeMatMulOp3DXFloat, + TestReshapeTransposeMatMulOp3DYFloat) def reference_matmul(X, Y, transpose_x=False, transpose_y=False): @@ -457,24 +454,6 @@ class TestMatMulV2OpTransposeReshapeOtherDimFloat( self.op_type = "matmul_v2" -class TestMatMulV2OpTransposeReshapeTransposeAxisNotSupportedException( - TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException): - def set_op_type(self): - self.op_type = "matmul_v2" - - -class TestMatMulV2OpTransposeReshapeRankOfReshapeNotSupportedException( - TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException): - def set_op_type(self): - self.op_type = "matmul_v2" - - -class TestMatMulV2OpTransposeReshapeTransposeRankNotSupportedException( - TestMatMulOpTransposeReshapeTransposeRankNotSupportedException): - def set_op_type(self): - self.op_type = "matmul_v2" - - class TestMatMulV2OpReshapeTranspose(TestReshapeTransposeMatMulOp): def set_op_type_and_transpose_y_name(self): self.op_type = "matmul_v2"