未验证 提交 5d08a447 编写于 作者: C Chen Weihang 提交者: GitHub

move allclose infershape (#40508)

上级 e157f2af
......@@ -15,10 +15,13 @@
#include <cmath>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
......@@ -61,40 +64,6 @@ class AllcloseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Allclose");
OP_INOUT_CHECK(ctx->HasInput("Other"), "Input", "Other", "Allclose");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Allclose");
auto input_dim = ctx->GetInputDim("Input");
auto other_dim = ctx->GetInputDim("Other");
PADDLE_ENFORCE_EQ(input_dim.size(), other_dim.size(),
platform::errors::PreconditionNotMet(
"Input(Input) and Input(Other) must have the same "
"dimension size."));
int n = input_dim.size();
bool is_runtime = ctx->IsRuntime();
for (int i = 0; i < n; i++) {
if (is_runtime) {
PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
platform::errors::PreconditionNotMet(
"The value at dim %d of Input(Input) is not "
"equal to the Input(Other): %ld != %ld.",
i, input_dim[i], other_dim[i]));
} else {
if (!(input_dim[i] < 0 || other_dim[i] < 0)) {
PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
platform::errors::PreconditionNotMet(
"The value at dim %d of Input(Input) is not "
"equal to the Input(Other): %ld != %ld.",
i, input_dim[i], other_dim[i]));
}
}
}
ctx->SetOutputDim("Out", phi::make_ddim({1}));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -117,11 +86,13 @@ class AllcloseOpVarTypeInference : public framework::VarTypeInference {
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
DECLARE_INFER_SHAPE_FUNCTOR(allclose, AllcloseInferShapeFunctor,
PD_INFER_META(phi::AllValueCompareInferMeta));
REGISTER_OPERATOR(
allclose, ops::AllcloseOp, ops::AllcloseOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::AllcloseOpVarTypeInference);
ops::AllcloseOpVarTypeInference, AllcloseInferShapeFunctor);
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(allclose)
......
......@@ -21,6 +21,56 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/common_shape.h"
namespace phi {
namespace detail {
static void BinarySameInputDimsCheck(const MetaTensor& x,
const MetaTensor& y,
MetaConfig config) {
auto input_dim = x.dims();
auto other_dim = y.dims();
PADDLE_ENFORCE_EQ(input_dim.size(),
other_dim.size(),
phi::errors::PreconditionNotMet(
"Input(Input) and Input(Other) must have the same "
"dimension size."));
int n = input_dim.size();
bool is_runtime = config.is_runtime;
for (int i = 0; i < n; i++) {
if (is_runtime) {
PADDLE_ENFORCE_EQ(input_dim[i],
other_dim[i],
phi::errors::PreconditionNotMet(
"The value at dim %d of Input(Input) is not "
"equal to the Input(Other): %ld != %ld.",
i,
input_dim[i],
other_dim[i]));
} else {
if (!(input_dim[i] < 0 || other_dim[i] < 0)) {
PADDLE_ENFORCE_EQ(input_dim[i],
other_dim[i],
phi::errors::PreconditionNotMet(
"The value at dim %d of Input(Input) is not "
"equal to the Input(Other): %ld != %ld.",
i,
input_dim[i],
other_dim[i]));
}
}
}
}
} // namespace detail
void AllValueCompareInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out,
MetaConfig config) {
detail::BinarySameInputDimsCheck(x, y, config);
out->set_dims(phi::make_ddim({1}));
out->set_dtype(DataType::BOOL);
}
void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
out->share_meta(x);
......
......@@ -29,6 +29,11 @@ namespace phi {
// Because functions in this file not only can infer shape, but also need
// infer lod or other useful data.
void AllValueCompareInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out,
MetaConfig config = MetaConfig());
void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
void BCELossInferMeta(const MetaTensor& input,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册