未验证 提交 88c03071 编写于 作者: C Chen Weihang 提交者: GitHub

polish trace op detail (#40425)

上级 6d830f6c
...@@ -61,7 +61,7 @@ the 2-D planes specified by dim1 and dim2. ...@@ -61,7 +61,7 @@ the 2-D planes specified by dim1 and dim2.
)DOC"); )DOC");
} }
}; };
class TraceOpGrad : public framework::OperatorWithKernel { class TraceGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -114,7 +114,7 @@ REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker, ...@@ -114,7 +114,7 @@ REGISTER_OPERATOR(trace, ops::TraceOp, ops::TraceOpMaker,
ops::TraceGradOpMaker<paddle::imperative::OpBase>, ops::TraceGradOpMaker<paddle::imperative::OpBase>,
TraceInferShapeFunctor); TraceInferShapeFunctor);
REGISTER_OPERATOR(trace_grad, ops::TraceOpGrad, REGISTER_OPERATOR(trace_grad, ops::TraceGradOp,
ops::TraceGradNoNeedBufferVarsInferer); ops::TraceGradNoNeedBufferVarsInferer);
/* ========================== register checkpoint ===========================*/ /* ========================== register checkpoint ===========================*/
......
...@@ -837,6 +837,7 @@ void TraceInferMeta( ...@@ -837,6 +837,7 @@ void TraceInferMeta(
sizes.erase(sizes.begin() + std::min(dim1_, dim2_)); sizes.erase(sizes.begin() + std::min(dim1_, dim2_));
} }
out->set_dims(phi::make_ddim(sizes)); out->set_dims(phi::make_ddim(sizes));
out->set_dtype(x.dtype());
} }
void DiagonalInferMeta(const MetaTensor& input, void DiagonalInferMeta(const MetaTensor& input,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册