未验证 提交 12ca438e 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Rename general grad infermeta func (#39578)

* rename general grad infermeta func

* remove useless code
上级 6b756fb7
......@@ -527,7 +527,7 @@ REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker,
ops::MatMulV2GradOpMaker<paddle::imperative::OpBase>);
DELCARE_INFER_SHAPE_FUNCTOR(matmul_v2_grad, MatMulV2GradInferShapeFunctor,
PT_INFER_META(pten::MatmulGradInferMeta));
PT_INFER_META(pten::GeneralBinaryGradInferMeta));
REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad,
ops::MatMulV2OpDoubleGradMaker<paddle::framework::OpDesc>,
ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>,
......
......@@ -16,13 +16,10 @@ limitations under the License. */
namespace pten {
void MatmulGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& out_grad_meta,
bool transpose_x,
bool transpose_y,
MetaTensor* dx,
MetaTensor* dy) {
void GeneralBinaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* dx,
MetaTensor* dy) {
if (dx) {
dx->share_meta(x);
}
......
......@@ -20,12 +20,9 @@ limitations under the License. */
namespace pten {
void MatmulGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& out_grad_meta,
bool transpose_x,
bool transpose_y,
MetaTensor* dx,
MetaTensor* dy);
void GeneralBinaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* dx,
MetaTensor* dy);
} // namespace pten
......@@ -3,7 +3,8 @@
args : (const Tensor& x, const Tensor& y, const Tensor& out_grad, bool transpose_x=false, bool transpose_y=false)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : MatmulGradInferMeta
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : matmul_grad
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册