Unverified · Commit c7637700 authored by Chen Weihang, committed by GitHub

move isclose infershape (#40595)

Parent 8fd20b5b
@@ -14,10 +14,13 @@
 #include <cmath>
 #include <string>
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/infermeta/binary.h"
 
 namespace paddle {
 namespace operators {
@@ -60,40 +63,6 @@ class IscloseOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Isclose");
-    OP_INOUT_CHECK(ctx->HasInput("Other"), "Input", "Other", "Isclose");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Isclose");
-
-    auto input_dim = ctx->GetInputDim("Input");
-    auto other_dim = ctx->GetInputDim("Other");
-    PADDLE_ENFORCE_EQ(input_dim.size(), other_dim.size(),
-                      platform::errors::PreconditionNotMet(
-                          "Input(Input) and Input(Other) must have the same "
-                          "dimension size."));
-
-    int n = input_dim.size();
-    bool is_runtime = ctx->IsRuntime();
-    for (int i = 0; i < n; i++) {
-      if (is_runtime) {
-        PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
-                          platform::errors::PreconditionNotMet(
-                              "The value at dim %d of Input(Input) is not "
-                              "equal to the Input(Other): %ld != %ld.",
-                              i, input_dim[i], other_dim[i]));
-      } else {
-        if (!(input_dim[i] < 0 || other_dim[i] < 0)) {
-          PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
-                            platform::errors::PreconditionNotMet(
-                                "The value at dim %d of Input(Input) is not "
-                                "equal to the Input(Other): %ld != %ld.",
-                                i, input_dim[i], other_dim[i]));
-        }
-      }
-    }
-
-    ctx->SetOutputDim("Out", input_dim);
-  }
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -115,8 +84,10 @@ class IscloseOpVarTypeInference : public framework::VarTypeInference {
 
 namespace ops = paddle::operators;
 
+DECLARE_INFER_SHAPE_FUNCTOR(isclose, IscloseInferShapeFunctor,
+                            PD_INFER_META(phi::ValueCompareInferMeta));
 REGISTER_OPERATOR(
     isclose, ops::IscloseOp, ops::IscloseOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    ops::IscloseOpVarTypeInference);
+    ops::IscloseOpVarTypeInference, IscloseInferShapeFunctor);
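Context note: DECLARE_INFER_SHAPE_FUNCTOR (from infershape_utils.h) generates a functor that adapts the framework's shape-inference entry point to the phi InferMeta function named in PD_INFER_META, which is why the hand-written IscloseOp::InferShape above can be deleted. A conceptual, standalone sketch of that adapter pattern follows; every name and type in it is an illustrative stand-in, not Paddle's actual API or macro expansion.

// Conceptual, standalone model of what DECLARE_INFER_SHAPE_FUNCTOR /
// PD_INFER_META provide: the op no longer implements InferShape itself;
// a generated functor forwards to a shared, reusable "infer meta" function.
// All names and types here are illustrative stand-ins, not Paddle's real API.
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Stand-in for phi::MetaTensor: just a shape and a dtype tag.
struct Meta {
  std::vector<int64_t> dims;
  std::string dtype;
};

// Shared shape/dtype rule, analogous to phi::ValueCompareInferMeta:
// the output takes the first input's shape and a boolean dtype.
void ValueCompareInferMeta(const Meta& x, const Meta& /*y*/, Meta* out) {
  out->dims = x.dims;
  out->dtype = "bool";
}

// Stand-in for the generated InferShape functor: it simply forwards the op's
// inputs/outputs to the shared infer meta function chosen at registration time.
using InferShapeFn = std::function<void(const Meta&, const Meta&, Meta*)>;

int main() {
  InferShapeFn isclose_infer_shape = ValueCompareInferMeta;  // "registration"

  Meta x{{2, 3}, "float32"};
  Meta y{{2, 3}, "float32"};
  Meta out;
  isclose_infer_shape(x, y, &out);
  std::cout << "out: " << out.dims[0] << "x" << out.dims[1] << ", "
            << out.dtype << std::endl;  // prints "out: 2x3, bool"
  return 0;
}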
@@ -862,6 +862,16 @@ void TriangularSolveInferMeta(const MetaTensor& x,
   out->share_lod(y);
 }
 
+void ValueCompareInferMeta(const MetaTensor& x,
+                           const MetaTensor& y,
+                           MetaTensor* out,
+                           MetaConfig config) {
+  detail::BinarySameInputDimsCheck(x, y, config);
+
+  out->set_dims(x.dims());
+  out->set_dtype(DataType::BOOL);
+}
+
 }  // namespace phi
 
 PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
@@ -142,4 +142,9 @@ void TriangularSolveInferMeta(const MetaTensor& x,
                               bool unitriangular,
                               MetaTensor* out);
 
+void ValueCompareInferMeta(const MetaTensor& x,
+                           const MetaTensor& y,
+                           MetaTensor* out,
+                           MetaConfig config = MetaConfig());
+
 }  // namespace phi