// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;

// Functor that reads the single scalar of type T out of `tensor`.
// Declaration only: the definitions are provided per device (CPU/GPU)
// elsewhere in the operator's .cc/.cu files — not visible in this header.
// Used below to fetch the runtime Rtol/Atol values from 1-element tensors.
template <typename DeviceContext, typename T>
struct GetTensorValue {
  T operator()(const platform::DeviceContext& ctx,
               const framework::Tensor& tensor) const;
};

// Functor performing the element-wise "allclose" comparison of `in` against
// `other` with relative tolerance `rtol` and absolute tolerance `atol`,
// writing the boolean result into `output`. If `equal_nan` is true, NaNs are
// presumably treated as equal — definition lives in the per-device
// .cc/.cu files, so confirm there.
template <typename DeviceContext, typename T>
struct AllcloseFunctor {
  void operator()(const DeviceContext& ctx, const framework::Tensor& in,
                  const framework::Tensor& other, const float rtol,
                  const float atol, bool equal_nan, framework::Tensor* output);
};

// Kernel for the `allclose` operator: reports whether every element of
// Input is close to the corresponding element of Other within the given
// relative/absolute tolerances (numpy.allclose semantics).
template <typename DeviceContext, typename T>
class AllcloseKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // If true, NaN values compare equal to each other.
    bool equal_nan = ctx.Attr<bool>("equal_nan");
    // Operands and the (boolean) result tensor.
    const auto* input = ctx.Input<Tensor>("Input");
    const auto* other = ctx.Input<Tensor>("Other");
    auto* out = ctx.Output<Tensor>("Out");

    // Tolerances default to the string attributes; either may be overridden
    // below by the optional Rtol/Atol input tensors.
    double rtol_v = std::stod(ctx.Attr<std::string>("rtol"));
    double atol_v = std::stod(ctx.Attr<std::string>("atol"));

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    GetTensorValue<DeviceContext, double> get_tensor_value;
    // Optional runtime tolerance: must be a single FP64 scalar.
    if (ctx.HasInput("Rtol")) {
      const auto* rtol = ctx.Input<Tensor>("Rtol");
      PADDLE_ENFORCE_EQ(
          rtol->numel(), 1,
          platform::errors::InvalidArgument(
              "Input(Rtol) size must be 1, but get %d.", rtol->numel()));
      PADDLE_ENFORCE_EQ(rtol->type(), framework::proto::VarType::FP64,
                        platform::errors::InvalidArgument(
                            "Input(Rtol) type must be double, but get %s.",
                            framework::DataTypeToString(rtol->type())));
      rtol_v = get_tensor_value(dev_ctx, *rtol);
    }
    if (ctx.HasInput("Atol")) {
      const auto* atol = ctx.Input<Tensor>("Atol");
      PADDLE_ENFORCE_EQ(
          atol->numel(), 1,
          platform::errors::InvalidArgument(
              "Input(Atol) size must be 1, but get %d.", atol->numel()));
      PADDLE_ENFORCE_EQ(atol->type(), framework::proto::VarType::FP64,
                        platform::errors::InvalidArgument(
                            "Input(Atol) type must be double, but get %s.",
                            framework::DataTypeToString(atol->type())));
      atol_v = get_tensor_value(dev_ctx, *atol);
    }

    // Delegate the element-wise comparison to the per-device functor.
    AllcloseFunctor<DeviceContext, T>()(dev_ctx, *input, *other, rtol_v, atol_v,
                                        equal_nan, out);
  }
};

}  // namespace operators
}  // namespace paddle