diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index eaec4b78f4fc0401c907fe0481d9b9e1da1b8ff4..40f2b625f65006061f24779c0aee2b92ec297890 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -527,7 +527,7 @@ REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker, ops::MatMulV2GradOpMaker); DELCARE_INFER_SHAPE_FUNCTOR(matmul_v2_grad, MatMulV2GradInferShapeFunctor, - PT_INFER_META(pten::MatmulGradInferMeta)); + PT_INFER_META(pten::GeneralBinaryGradInferMeta)); REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad, ops::MatMulV2OpDoubleGradMaker, ops::MatMulV2OpDoubleGradMaker, diff --git a/paddle/pten/infermeta/backward.cc b/paddle/pten/infermeta/backward.cc index db92449519436024a01c9c891f9671756777a345..2f2fcc7db31ea51f2111103675bbd20e7ab1ec58 100644 --- a/paddle/pten/infermeta/backward.cc +++ b/paddle/pten/infermeta/backward.cc @@ -16,13 +16,10 @@ limitations under the License. */ namespace pten { -void MatmulGradInferMeta(const MetaTensor& x, - const MetaTensor& y, - const MetaTensor& out_grad_meta, - bool transpose_x, - bool transpose_y, - MetaTensor* dx, - MetaTensor* dy) { +void GeneralBinaryGradInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* dx, + MetaTensor* dy) { if (dx) { dx->share_meta(x); } diff --git a/paddle/pten/infermeta/backward.h b/paddle/pten/infermeta/backward.h index d6b96861412861de6fb892a28c3930bd7db20da7..ded51cac6378c574232eed3e641def23c68c3db8 100644 --- a/paddle/pten/infermeta/backward.h +++ b/paddle/pten/infermeta/backward.h @@ -20,12 +20,9 @@ limitations under the License. */ namespace pten { -void MatmulGradInferMeta(const MetaTensor& x, - const MetaTensor& y, - const MetaTensor& out_grad_meta, - bool transpose_x, - bool transpose_y, - MetaTensor* dx, - MetaTensor* dy); +void GeneralBinaryGradInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* dx, + MetaTensor* dy); } // namespace pten diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index d14cf11c8dd7eaea2482e7a043c76530fc6fc7d7..b1227217b3cab5102c4b9555abc2f93a7f8e82fa 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -3,7 +3,8 @@ args : (const Tensor& x, const Tensor& y, const Tensor& out_grad, bool transpose_x=false, bool transpose_y=false) output : Tensor(x_grad), Tensor(y_grad) infer_meta : - func : MatmulGradInferMeta + func : GeneralBinaryGradInferMeta + param : [x, y] kernel : func : matmul_grad