allclose_op.cc 5.9 KB
Newer Older
Z
Zhen Wang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/allclose_op.h"

#include <cmath>
#include <cstdint>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {

H
huangxu96 已提交
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
template <typename T>
struct GetTensorValue<platform::CPUDeviceContext, T> {
  T operator()(const platform::CPUDeviceContext& dev_ctx,
               const framework::Tensor& tensor) const {
    return *(tensor.data<T>());
  }
};

template <typename T>
struct AllcloseFunctor<platform::CPUDeviceContext, T> {
  // Writes a single bool into `output`: true iff every element pair
  // (in[i], other[i]) satisfies |a - b| <= atol + rtol * |b|, using the
  // numpy.isclose conventions for special values:
  //   * NaN is only close to NaN, and only when `equal_nan` is true;
  //   * an infinite value is only close to the identical infinity.
  // NOTE(review): `in` and `other` are assumed to hold the same number of
  // elements (AllcloseOp::InferShape enforces matching dims); only
  // in.numel() is consulted here.
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& other,
                  const double rtol, const double atol, bool equal_nan,
                  framework::Tensor* output) {
    const T* in_a = in.data<T>();
    const T* in_b = other.data<T>();
    bool* out_data = output->mutable_data<bool>(ctx.GetPlace());
    // numel() is 64-bit; an `int` index would overflow on huge tensors.
    const int64_t num = in.numel();
    bool all_close = true;
    // Once one pair fails the overall result is fixed, so stop early.
    for (int64_t i = 0; all_close && i < num; ++i) {
      all_close = IsClose(in_a[i], in_b[i], rtol, atol, equal_nan);
    }
    *out_data = all_close;
  }

 private:
  // Element-wise closeness predicate used by operator().
  static bool IsClose(const T a, const T b, const double rtol,
                      const double atol, const bool equal_nan) {
    const bool a_nan = std::isnan(a);
    const bool b_nan = std::isnan(b);
    if (a_nan || b_nan) {
      return equal_nan && a_nan && b_nan;
    }
    // An infinite `b` would make the tolerance infinite, so any finite `a`
    // (or the opposite infinity) would wrongly pass; numpy compares
    // non-finite values by equality only.
    if (std::isinf(a) || std::isinf(b)) {
      return a == b;
    }
    const T left = std::abs(a - b);
    const T right = static_cast<T>(atol + rtol * std::abs(b));
    // Results missing the bound by at most 1e-15 are still accepted;
    // presumably this absorbs rounding in `right` -- NOTE(review): confirm.
    const T diff = std::abs(left - right);
    return a == b || left <= right || diff <= 1e-15;
  }
};

Z
Zhen Wang 已提交
59 60 61
class AllcloseOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
62 63 64 65
    AddInput("Input",
             "The input tensor, it's data type should be float32, float64.");
    AddInput("Other",
             "The input tensor, it's data type should be float32, float64.");
H
huangxu96 已提交
66 67
    AddInput("Rtol", "The relative tolerance.");
    AddInput("Atol", "The absolute tolerance.");
68
    AddOutput("Out", "The output tensor, it's data type is bool.");
Z
Zhen Wang 已提交
69 70 71 72 73 74
    AddAttr<bool>("equal_nan",
                  "If :math:`True` , then two :math:`NaNs` will be "
                  "compared as equal. Default: :math:`False` .")
        .SetDefault(false);

    AddComment(R"DOC( 
75
This operator checks if all :math:`x` and :math:`y` satisfy the condition:
Z
Zhen Wang 已提交
76

77 78
.. math::
    \left| x - y \right| \leq atol + rtol \times \left| y \right|
Z
Zhen Wang 已提交
79

80
elementwise, for all elements of :math:`x` and :math:`y`. The behaviour of this
Z
Zhen Wang 已提交
81 82 83 84 85 86 87 88 89 90
operator is analogous to :math:`numpy.allclose`, namely that it returns :math:`True` if
two tensors are elementwise equal within a tolerance.
)DOC");
  }
};

class AllcloseOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

H
huangxu96 已提交
91 92 93 94 95 96
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Allclose");
    OP_INOUT_CHECK(ctx->HasInput("Other"), "Input", "Other", "Allclose");
    OP_INOUT_CHECK(ctx->HasInput("Rtol"), "Input", "Rtol", "Allclose");
    OP_INOUT_CHECK(ctx->HasInput("Atol"), "Input", "Atol", "Allclose");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Allclose");
Z
Zhen Wang 已提交
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128

    auto input_dim = ctx->GetInputDim("Input");
    auto other_dim = ctx->GetInputDim("Other");
    PADDLE_ENFORCE_EQ(input_dim.size(), other_dim.size(),
                      platform::errors::PreconditionNotMet(
                          "Input(Input) and Input(Other) must have the same "
                          "dimension size."));
    int n = input_dim.size();
    bool is_runtime = ctx->IsRuntime();
    for (int i = 0; i < n; i++) {
      if (is_runtime) {
        PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
                          platform::errors::PreconditionNotMet(
                              "The value at dim %d of Input(Input) is not "
                              "equal to the Input(Other): %ld != %ld.",
                              i, input_dim[i], other_dim[i]));
      } else {
        if (!(input_dim[i] < 0 || other_dim[i] < 0)) {
          PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
                            platform::errors::PreconditionNotMet(
                                "The value at dim %d of Input(Input) is not "
                                "equal to the Input(Other): %ld != %ld.",
                                i, input_dim[i], other_dim[i]));
        }
      }
    }

    ctx->SetOutputDim("Out", framework::make_ddim({1}));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
H
huangxu96 已提交
129
      const framework::ExecutionContext& ctx) const override {
Z
Zhen Wang 已提交
130 131 132 133 134 135 136 137
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
        ctx.device_context());
  }
};

// Variable-type inference for allclose: the "Out" variable is always marked
// BOOL, independent of the data type of the inputs.
class AllcloseOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext* ctx) const override {
    ctx->SetOutputDataType("Out", framework::proto::VarType::BOOL);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CPUContext = paddle::platform::CPUDeviceContext;

// allclose produces no gradient: empty grad-op makers are registered for both
// OpDesc-based and imperative::OpBase-based programs.
REGISTER_OPERATOR(allclose,
                  ops::AllcloseOp,
                  ops::AllcloseOpMaker,
                  paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
                  paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
                  ops::AllcloseOpVarTypeInference);

// CPU kernels exist for float and double inputs only.
REGISTER_OP_CPU_KERNEL(allclose,
                       ops::AllcloseKernel<CPUContext, float>,
                       ops::AllcloseKernel<CPUContext, double>);