未验证 提交 12ca438e 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Rename general grad infermeta func (#39578)

* rename general grad infermeta func

* remove useless code
上级 6b756fb7
...@@ -527,7 +527,7 @@ REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker, ...@@ -527,7 +527,7 @@ REGISTER_OPERATOR(matmul_v2, ops::MatMulV2Op, ops::MatMulV2OpMaker,
ops::MatMulV2GradOpMaker<paddle::imperative::OpBase>); ops::MatMulV2GradOpMaker<paddle::imperative::OpBase>);
DELCARE_INFER_SHAPE_FUNCTOR(matmul_v2_grad, MatMulV2GradInferShapeFunctor, DELCARE_INFER_SHAPE_FUNCTOR(matmul_v2_grad, MatMulV2GradInferShapeFunctor,
PT_INFER_META(pten::MatmulGradInferMeta)); PT_INFER_META(pten::GeneralBinaryGradInferMeta));
REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad, REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad,
ops::MatMulV2OpDoubleGradMaker<paddle::framework::OpDesc>, ops::MatMulV2OpDoubleGradMaker<paddle::framework::OpDesc>,
ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>, ops::MatMulV2OpDoubleGradMaker<paddle::imperative::OpBase>,
......
...@@ -16,13 +16,10 @@ limitations under the License. */ ...@@ -16,13 +16,10 @@ limitations under the License. */
namespace pten { namespace pten {
void MatmulGradInferMeta(const MetaTensor& x, void GeneralBinaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y, const MetaTensor& y,
const MetaTensor& out_grad_meta, MetaTensor* dx,
bool transpose_x, MetaTensor* dy) {
bool transpose_y,
MetaTensor* dx,
MetaTensor* dy) {
if (dx) { if (dx) {
dx->share_meta(x); dx->share_meta(x);
} }
......
...@@ -20,12 +20,9 @@ limitations under the License. */ ...@@ -20,12 +20,9 @@ limitations under the License. */
namespace pten { namespace pten {
void MatmulGradInferMeta(const MetaTensor& x, void GeneralBinaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y, const MetaTensor& y,
const MetaTensor& out_grad_meta, MetaTensor* dx,
bool transpose_x, MetaTensor* dy);
bool transpose_y,
MetaTensor* dx,
MetaTensor* dy);
} // namespace pten } // namespace pten
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
args : (const Tensor& x, const Tensor& y, const Tensor& out_grad, bool transpose_x=false, bool transpose_y=false) args : (const Tensor& x, const Tensor& y, const Tensor& out_grad, bool transpose_x=false, bool transpose_y=false)
output : Tensor(x_grad), Tensor(y_grad) output : Tensor(x_grad), Tensor(y_grad)
infer_meta : infer_meta :
func : MatmulGradInferMeta func : GeneralBinaryGradInferMeta
param : [x, y]
kernel : kernel :
func : matmul_grad func : matmul_grad
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册