From 88c03071cda368844f84305b25c28c5071d91965 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Fri, 11 Mar 2022 18:10:52 +0800 Subject: [PATCH] polish trace op detail (#40425) --- paddle/fluid/operators/trace_op.cc | 4 ++-- paddle/phi/infermeta/unary.cc | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/trace_op.cc b/paddle/fluid/operators/trace_op.cc index 0590b66f6f8..c6c0fa3c001 100644 --- a/paddle/fluid/operators/trace_op.cc +++ b/paddle/fluid/operators/trace_op.cc @@ -61,7 +61,7 @@ the 2-D planes specified by dim1 and dim2. )DOC"); } }; -class TraceOpGrad : public framework::OperatorWithKernel { +class TraceGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -114,7 +114,7 @@ REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker, ops::TraceGradOpMaker, TraceInferShapeFunctor); -REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad, +REGISTER_OPERATOR(trace_grad, ops::TraceGradOp, ops::TraceGradNoNeedBufferVarsInferer); /* ========================== register checkpoint ===========================*/ diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index d6d4efad9fa..9daad7d6aaa 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -837,6 +837,7 @@ void TraceInferMeta( sizes.erase(sizes.begin() + std::min(dim1_, dim2_)); } out->set_dims(phi::make_ddim(sizes)); + out->set_dtype(x.dtype()); } void DiagonalInferMeta(const MetaTensor& input, -- GitLab