diff --git a/paddle/fluid/operators/trace_op.cc b/paddle/fluid/operators/trace_op.cc index 0590b66f6f868858d66e95382f96c8ad42ac64c2..c6c0fa3c0019eac742a9c70ea53a438f5a474895 100644 --- a/paddle/fluid/operators/trace_op.cc +++ b/paddle/fluid/operators/trace_op.cc @@ -61,7 +61,7 @@ the 2-D planes specified by dim1 and dim2. )DOC"); } }; -class TraceOpGrad : public framework::OperatorWithKernel { +class TraceGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -114,7 +114,7 @@ REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker, ops::TraceGradOpMaker, TraceInferShapeFunctor); -REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad, +REGISTER_OPERATOR(trace_grad, ops::TraceGradOp, ops::TraceGradNoNeedBufferVarsInferer); /* ========================== register checkpoint ===========================*/ diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index d6d4efad9fae26e6cbdb914752d8c24ab23d948c..9daad7d6aaa9f5af70b4b7c3b4bfa96bc351194b 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -837,6 +837,7 @@ void TraceInferMeta( sizes.erase(sizes.begin() + std::min(dim1_, dim2_)); } out->set_dims(phi::make_ddim(sizes)); + out->set_dtype(x.dtype()); } void DiagonalInferMeta(const MetaTensor& input,