未验证 提交 a09a93a1 编写于 作者: C Chen Weihang 提交者: GitHub

move determinant op infershape (#40624)

上级 bef6f2e1
...@@ -13,6 +13,10 @@ ...@@ -13,6 +13,10 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/determinant_op.h" #include "paddle/fluid/operators/determinant_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -20,11 +24,6 @@ namespace operators { ...@@ -20,11 +24,6 @@ namespace operators {
class DeterminantOp : public framework::OperatorWithKernel { class DeterminantOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "determinant");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "determinant");
}
}; };
class DeterminantOpMaker : public framework::OpProtoAndCheckerMaker { class DeterminantOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -44,19 +43,6 @@ class DeterminantGradOp : public framework::OperatorWithKernel { ...@@ -44,19 +43,6 @@ class DeterminantGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input",
"DeterminantGradOp");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "DeterminantGradOp");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "DeterminantGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output",
framework::GradVarName("Input"), "DeterminantGradOp");
ctx->SetOutputDim(framework::GradVarName("Input"),
ctx->GetInputDim("Input"));
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
...@@ -162,11 +148,17 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(SlogDeterminantGradNoNeedBufferVarsInferer, ...@@ -162,11 +148,17 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(SlogDeterminantGradNoNeedBufferVarsInferer,
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(determinant, DeterminantInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(determinant, ops::DeterminantOp, ops::DeterminantOpMaker, REGISTER_OPERATOR(determinant, ops::DeterminantOp, ops::DeterminantOpMaker,
ops::DeterminantGradOpMaker<paddle::framework::OpDesc>, ops::DeterminantGradOpMaker<paddle::framework::OpDesc>,
ops::DeterminantGradOpMaker<paddle::imperative::OpBase>); ops::DeterminantGradOpMaker<paddle::imperative::OpBase>,
DeterminantInferShapeFunctor);
REGISTER_OPERATOR(determinant_grad, ops::DeterminantGradOp) DECLARE_INFER_SHAPE_FUNCTOR(determinant_grad, DeterminantGradInferShapeFunctor,
PD_INFER_META(phi::GeneralUnaryGradInferMeta));
REGISTER_OPERATOR(determinant_grad, ops::DeterminantGradOp,
DeterminantGradInferShapeFunctor);
REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp, REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp,
ops::SlogDeterminantOpMaker, ops::SlogDeterminantOpMaker,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册