未验证 提交 88c03071 编写于 作者: C Chen Weihang 提交者: GitHub

polish trace op detail (#40425)

上级 6d830f6c
......@@ -61,7 +61,7 @@ the 2-D planes specified by dim1 and dim2.
)DOC");
}
};
class TraceOpGrad : public framework::OperatorWithKernel {
class TraceGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -114,7 +114,7 @@ REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker,
ops::TraceGradOpMaker<paddle::imperative::OpBase>,
TraceInferShapeFunctor);
REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad,
REGISTER_OPERATOR(trace_grad, ops::TraceGradOp,
ops::TraceGradNoNeedBufferVarsInferer);
/* ========================== register checkpoint ===========================*/
......
......@@ -837,6 +837,7 @@ void TraceInferMeta(
sizes.erase(sizes.begin() + std::min(dim1_, dim2_));
}
out->set_dims(phi::make_ddim(sizes));
out->set_dtype(x.dtype());
}
void DiagonalInferMeta(const MetaTensor& input,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册