未验证 提交 a09a93a1 编写于 作者: C Chen Weihang 提交者: GitHub

move determinant op infershape (#40624)

上级 bef6f2e1
......@@ -13,6 +13,10 @@
// limitations under the License.
#include "paddle/fluid/operators/determinant_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -20,11 +24,6 @@ namespace operators {
class DeterminantOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "determinant");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "determinant");
}
};
class DeterminantOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -44,19 +43,6 @@ class DeterminantGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input",
"DeterminantGradOp");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "DeterminantGradOp");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "DeterminantGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output",
framework::GradVarName("Input"), "DeterminantGradOp");
ctx->SetOutputDim(framework::GradVarName("Input"),
ctx->GetInputDim("Input"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -162,11 +148,17 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(SlogDeterminantGradNoNeedBufferVarsInferer,
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(determinant, DeterminantInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(determinant, ops::DeterminantOp, ops::DeterminantOpMaker,
ops::DeterminantGradOpMaker<paddle::framework::OpDesc>,
ops::DeterminantGradOpMaker<paddle::imperative::OpBase>);
ops::DeterminantGradOpMaker<paddle::imperative::OpBase>,
DeterminantInferShapeFunctor);
REGISTER_OPERATOR(determinant_grad, ops::DeterminantGradOp)
DECLARE_INFER_SHAPE_FUNCTOR(determinant_grad, DeterminantGradInferShapeFunctor,
PD_INFER_META(phi::GeneralUnaryGradInferMeta));
REGISTER_OPERATOR(determinant_grad, ops::DeterminantGradOp,
DeterminantGradInferShapeFunctor);
REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp,
ops::SlogDeterminantOpMaker,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册