未验证 提交 c7637700 编写于 作者: C Chen Weihang 提交者: GitHub

move isclose infershape (#40595)

上级 8fd20b5b
......@@ -14,10 +14,13 @@
#include <cmath>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
......@@ -60,40 +63,6 @@ class IscloseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Isclose");
OP_INOUT_CHECK(ctx->HasInput("Other"), "Input", "Other", "Isclose");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Isclose");
auto input_dim = ctx->GetInputDim("Input");
auto other_dim = ctx->GetInputDim("Other");
PADDLE_ENFORCE_EQ(input_dim.size(), other_dim.size(),
platform::errors::PreconditionNotMet(
"Input(Input) and Input(Other) must have the same "
"dimension size."));
int n = input_dim.size();
bool is_runtime = ctx->IsRuntime();
for (int i = 0; i < n; i++) {
if (is_runtime) {
PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
platform::errors::PreconditionNotMet(
"The value at dim %d of Input(Input) is not "
"equal to the Input(Other): %ld != %ld.",
i, input_dim[i], other_dim[i]));
} else {
if (!(input_dim[i] < 0 || other_dim[i] < 0)) {
PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
platform::errors::PreconditionNotMet(
"The value at dim %d of Input(Input) is not "
"equal to the Input(Other): %ld != %ld.",
i, input_dim[i], other_dim[i]));
}
}
}
ctx->SetOutputDim("Out", input_dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -115,8 +84,10 @@ class IscloseOpVarTypeInference : public framework::VarTypeInference {
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(isclose, IscloseInferShapeFunctor,
PD_INFER_META(phi::ValueCompareInferMeta));
REGISTER_OPERATOR(
isclose, ops::IscloseOp, ops::IscloseOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::IscloseOpVarTypeInference);
ops::IscloseOpVarTypeInference, IscloseInferShapeFunctor);
......@@ -862,6 +862,16 @@ void TriangularSolveInferMeta(const MetaTensor& x,
out->share_lod(y);
}
void ValueCompareInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out,
MetaConfig config) {
detail::BinarySameInputDimsCheck(x, y, config);
out->set_dims(x.dims());
out->set_dtype(DataType::BOOL);
}
} // namespace phi
PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
......@@ -142,4 +142,9 @@ void TriangularSolveInferMeta(const MetaTensor& x,
bool unitriangular,
MetaTensor* out);
void ValueCompareInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out,
MetaConfig config = MetaConfig());
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册