From 12ca438e85f1dade8ba642ab03bd2f1e72a02d88 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 16 Feb 2022 10:43:34 +0800 Subject: [PATCH] [PTen] Rename general grad infermeta func (#39578) * rename general grad infermeta func * remove useless code --- paddle/fluid/operators/matmul_v2_op.cc | 2 +- paddle/pten/infermeta/backward.cc | 11 ++++------- paddle/pten/infermeta/backward.h | 11 ++++------- python/paddle/utils/code_gen/backward.yaml | 3 ++- 4 files changed, 11 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index eaec4b78f4f..40f2b625f65 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -527,7 +527,7 @@ REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker, ops::MatMulV2GradOpMaker); DELCARE_INFER_SHAPE_FUNCTOR(matmul_v2_grad, MatMulV2GradInferShapeFunctor, - PT_INFER_META(pten::MatmulGradInferMeta)); + PT_INFER_META(pten::GeneralBinaryGradInferMeta)); REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad, ops::MatMulV2OpDoubleGradMaker, ops::MatMulV2OpDoubleGradMaker, diff --git a/paddle/pten/infermeta/backward.cc b/paddle/pten/infermeta/backward.cc index db924495194..2f2fcc7db31 100644 --- a/paddle/pten/infermeta/backward.cc +++ b/paddle/pten/infermeta/backward.cc @@ -16,13 +16,10 @@ limitations under the License. */ namespace pten { -void MatmulGradInferMeta(const MetaTensor& x, - const MetaTensor& y, - const MetaTensor& out_grad_meta, - bool transpose_x, - bool transpose_y, - MetaTensor* dx, - MetaTensor* dy) { +void GeneralBinaryGradInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* dx, + MetaTensor* dy) { if (dx) { dx->share_meta(x); } diff --git a/paddle/pten/infermeta/backward.h b/paddle/pten/infermeta/backward.h index d6b96861412..ded51cac637 100644 --- a/paddle/pten/infermeta/backward.h +++ b/paddle/pten/infermeta/backward.h @@ -20,12 +20,9 @@ limitations under the License. */ namespace pten { -void MatmulGradInferMeta(const MetaTensor& x, - const MetaTensor& y, - const MetaTensor& out_grad_meta, - bool transpose_x, - bool transpose_y, - MetaTensor* dx, - MetaTensor* dy); +void GeneralBinaryGradInferMeta(const MetaTensor& x, + const MetaTensor& y, + MetaTensor* dx, + MetaTensor* dy); } // namespace pten diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index d14cf11c8dd..b1227217b3c 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -3,7 +3,8 @@ args : (const Tensor& x, const Tensor& y, const Tensor& out_grad, bool transpose_x=false, bool transpose_y=false) output : Tensor(x_grad), Tensor(y_grad) infer_meta : - func : MatmulGradInferMeta + func : GeneralBinaryGradInferMeta + param : [x, y] kernel : func : matmul_grad -- GitLab